+ PhysCtrl achieves controllable and physics-grounded video generation from an initial force.
+
+
+
+
+
+
+
+
+
Abstract
+
+
+ Existing video generation models excel at producing photo-realistic videos from text or images, but often
+ lack physical plausibility and 3D controllability. To overcome these limitations, we introduce PhysCtrl, a
+ novel framework for physics-grounded image-to-video generation with physical parameters and force control.
+ At its core is a generative physics network that learns the distribution of physical dynamics across four
+ materials (elastic, sand, plasticine, and rigid) via a diffusion model conditioned on physics parameters
+ and applied forces. We represent physical dynamics as 3D point trajectories and train on a large-scale
+ synthetic dataset of 550K animations generated by physics simulators. We enhance the diffusion model with
+ a novel spatiotemporal attention block that emulates particle interactions and incorporates physics-based
+ constraints during training to enforce physical plausibility. Experiments show that PhysCtrl generates
+ realistic, physics-grounded motion trajectories which, when used to drive image-to-video models, yield
+ high-fidelity, controllable videos that outperform existing methods in both visual quality and physical
+ plausibility.
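+
+ As a rough sketch of the conditioning interface described above (hypothetical names and layer sizes;
+ not the released implementation), the trajectory denoiser can be thought of as a network that takes
+ noisy 3D point trajectories plus the physics parameters, applied force, and diffusion timestep, and
+ predicts the noise, with attention applied across points (particle interactions) and across frames:
+
+import torch
+import torch.nn as nn
+
+class TrajectoryDenoiser(nn.Module):
+    # Illustrative only: denoises 3D point trajectories [B, T, N, 3] conditioned on
+    # physics parameters (e.g. stiffness / Poisson ratio) and an applied force vector.
+    def __init__(self, d_model=256, n_heads=8, n_layers=2):
+        super().__init__()
+        self.point_embed = nn.Linear(3, d_model)
+        self.cond_embed = nn.Linear(5, d_model)   # assumed layout: [param_1, param_2, fx, fy, fz]
+        self.time_embed = nn.Linear(1, d_model)
+        self.spatial = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True), n_layers)  # across points
+        self.temporal = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True), n_layers)  # across frames
+        self.head = nn.Linear(d_model, 3)
+
+    def forward(self, noisy_traj, t, physics_params, force):
+        # noisy_traj: [B, T, N, 3], t: [B, 1], physics_params: [B, 2], force: [B, 3]
+        B, T, N, _ = noisy_traj.shape
+        cond = self.cond_embed(torch.cat([physics_params, force], dim=-1)) + self.time_embed(t)
+        h = self.point_embed(noisy_traj) + cond[:, None, None, :]
+        h = self.spatial(h.reshape(B * T, N, -1)).reshape(B, T, N, -1)   # spatial attention per frame
+        h = self.temporal(h.transpose(1, 2).reshape(B * N, T, -1)).reshape(B, N, T, -1).transpose(1, 2)
+        return self.head(h)   # predicted noise on the trajectories, [B, T, N, 3]
+
+ During training, such a network would regress the noise added to simulator-generated trajectories,
+ with the physics-based constraints mentioned above incorporated as auxiliary losses.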
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Pipeline
+
+
+
+ Given a single image, we lift the object in the image into 3D points. We train a diffusion-based
+ trajectory generation model, conditioned on physics parameters and an external force, to generate
+ the motion of these points; the generated trajectories are then used as strong physics-grounded
+ guidance for image-to-video generation, as sketched below.
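+
+ A minimal sketch of this glue logic, with hypothetical callables standing in for the three stages
+ (the actual interfaces in the code base differ):
+
+def generate_video(image, lift_to_points, sample_trajectories, image_to_video_model,
+                   material, physics_params, force):
+    # Glue logic only: the three callables stand in for the image-to-3D lifter,
+    # the trajectory diffusion sampler, and the trajectory-guided video generator.
+    points = lift_to_points(image)                                        # [N, 3] object points
+    trajs = sample_trajectories(points, material, physics_params, force)  # [T, N, 3] physics-grounded motion
+    return image_to_video_model(image, trajs)                             # trajectories guide video synthesis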
+
+
+
+
+
+
+
+
+
+
+
Force Control
+
+
+
+
+
+
+
+
+
+
+
+
Material Control
+
+
+
+
+
+
+
+
+
+
+
+
Comparison
+
+
+
+
+
+ A pair of wireless headphones rests on a white table before lifting into the air, as if there is an
+ invisible force applied to its handle.
+
+
+
+
+
+
ObjCtrl
+
+
+
+
DragAnything
+
+
+
+
CogVideo
+
+
+
+
Wan2.2
+
+
+
+
Ours
+
+
+
+
+
+
+
+
+ A yellow plasticine dinosaur toy free-falls to the ground due to gravity. It does not deform before
+ it touches the ground; after it touches the ground, it deforms.
+
+
+
+
+
+
ObjCtrl
+
+
+
+
DragAnything
+
+
+
+
CogVideo
+
+
+
+
Wan2.2
+
+
+
+
Ours
+
+
+
+
+
+
+
+
+
+ The penguin is fully lifted upwards and floats into the air with a natural motion, as if a force were
+ applied to its left wing. No webbed feet; realistic claws and flippers.
+
+
+
+
+
+
ObjCtrl
+
+
+
+
DragAnything
+
+
+
+
CogVideo
+
+
+
+
Wan2.2
+
+
+
+
Ours
+
+
+
+
+
+
+
+
+ A black cylindrical pipe lies on a wooden surface before rising and bending at a sharp angle. The
+ transformation is smooth and fluid, as if an invisible upward force is applied in the middle of the
+ pipe.
+
+
+
+
+
+
ObjCtrl
+
+
+
+
DragAnything
+
+
+
+
CogVideo
+
+
+
+
Wan2.2
+
+
+
+
Ours
+
+
+
+
+
+
+
+
+
+
+
+
+
+
BibTeX
+
+@inproceedings{physctrl2025,
+  author    = {Chen Wang* and Chuhao Chen* and Yiming Huang and Zhiyang Dou and Yuan Liu and Jiatao Gu and Lingjie Liu},
+  title     = {PhysCtrl: Generative Physics for Controllable and Physics-Grounded Video Generation},
+  booktitle = {NeurIPS},
+  year      = {2025},
+}
+
'}},function(t,e,i){"use strict";e.a=function(){return''}}]).default});
\ No newline at end of file
diff --git a/docs/static/js/bulma-slider.js b/docs/static/js/bulma-slider.js
new file mode 100644
index 0000000000000000000000000000000000000000..c6718de5c5ae59d2c22141a147f5afba41af9cbb
--- /dev/null
+++ b/docs/static/js/bulma-slider.js
@@ -0,0 +1,461 @@
+(function webpackUniversalModuleDefinition(root, factory) {
+ if(typeof exports === 'object' && typeof module === 'object')
+ module.exports = factory();
+ else if(typeof define === 'function' && define.amd)
+ define([], factory);
+ else if(typeof exports === 'object')
+ exports["bulmaSlider"] = factory();
+ else
+ root["bulmaSlider"] = factory();
+})(typeof self !== 'undefined' ? self : this, function() {
+return /******/ (function(modules) { // webpackBootstrap
+/******/ // The module cache
+/******/ var installedModules = {};
+/******/
+/******/ // The require function
+/******/ function __webpack_require__(moduleId) {
+/******/
+/******/ // Check if module is in cache
+/******/ if(installedModules[moduleId]) {
+/******/ return installedModules[moduleId].exports;
+/******/ }
+/******/ // Create a new module (and put it into the cache)
+/******/ var module = installedModules[moduleId] = {
+/******/ i: moduleId,
+/******/ l: false,
+/******/ exports: {}
+/******/ };
+/******/
+/******/ // Execute the module function
+/******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__);
+/******/
+/******/ // Flag the module as loaded
+/******/ module.l = true;
+/******/
+/******/ // Return the exports of the module
+/******/ return module.exports;
+/******/ }
+/******/
+/******/
+/******/ // expose the modules object (__webpack_modules__)
+/******/ __webpack_require__.m = modules;
+/******/
+/******/ // expose the module cache
+/******/ __webpack_require__.c = installedModules;
+/******/
+/******/ // define getter function for harmony exports
+/******/ __webpack_require__.d = function(exports, name, getter) {
+/******/ if(!__webpack_require__.o(exports, name)) {
+/******/ Object.defineProperty(exports, name, {
+/******/ configurable: false,
+/******/ enumerable: true,
+/******/ get: getter
+/******/ });
+/******/ }
+/******/ };
+/******/
+/******/ // getDefaultExport function for compatibility with non-harmony modules
+/******/ __webpack_require__.n = function(module) {
+/******/ var getter = module && module.__esModule ?
+/******/ function getDefault() { return module['default']; } :
+/******/ function getModuleExports() { return module; };
+/******/ __webpack_require__.d(getter, 'a', getter);
+/******/ return getter;
+/******/ };
+/******/
+/******/ // Object.prototype.hasOwnProperty.call
+/******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); };
+/******/
+/******/ // __webpack_public_path__
+/******/ __webpack_require__.p = "";
+/******/
+/******/ // Load entry module and return exports
+/******/ return __webpack_require__(__webpack_require__.s = 0);
+/******/ })
+/************************************************************************/
+/******/ ([
+/* 0 */
+/***/ (function(module, __webpack_exports__, __webpack_require__) {
+
+"use strict";
+Object.defineProperty(__webpack_exports__, "__esModule", { value: true });
+/* harmony export (binding) */ __webpack_require__.d(__webpack_exports__, "isString", function() { return isString; });
+/* harmony import */ var __WEBPACK_IMPORTED_MODULE_0__events__ = __webpack_require__(1);
+var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; };
+
+var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();
+
+var _typeof = typeof Symbol === "function" && typeof Symbol.iterator === "symbol" ? function (obj) { return typeof obj; } : function (obj) { return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj; };
+
+function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }
+
+function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; }
+
+function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; }
+
+
+
+var isString = function isString(unknown) {
+ return typeof unknown === 'string' || !!unknown && (typeof unknown === 'undefined' ? 'undefined' : _typeof(unknown)) === 'object' && Object.prototype.toString.call(unknown) === '[object String]';
+};
+
+var bulmaSlider = function (_EventEmitter) {
+ _inherits(bulmaSlider, _EventEmitter);
+
+ function bulmaSlider(selector) {
+ var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
+
+ _classCallCheck(this, bulmaSlider);
+
+ var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this));
+
+ _this.element = typeof selector === 'string' ? document.querySelector(selector) : selector;
+ // An invalid selector or non-DOM node has been provided.
+ if (!_this.element) {
+ throw new Error('An invalid selector or non-DOM node has been provided.');
+ }
+
+ _this._clickEvents = ['click'];
+ /// Set default options and merge with instance defined
+ _this.options = _extends({}, options);
+
+ _this.onSliderInput = _this.onSliderInput.bind(_this);
+
+ _this.init();
+ return _this;
+ }
+
+ /**
+ * Initiate all DOM element containing selector
+ * @method
+ * @return {Array} Array of all slider instances
+ */
+
+
+ _createClass(bulmaSlider, [{
+ key: 'init',
+
+
+ /**
+ * Initiate plugin
+ * @method init
+ * @return {void}
+ */
+ value: function init() {
+ this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999));
+ this.output = this._findOutputForSlider();
+
+ this._bindEvents();
+
+ if (this.output) {
+ if (this.element.classList.contains('has-output-tooltip')) {
+ // Get new output position
+ var newPosition = this._getSliderOutputPosition();
+
+ // Set output position
+ this.output.style['left'] = newPosition.position;
+ }
+ }
+
+ this.emit('bulmaslider:ready', this.element.value);
+ }
+ }, {
+ key: '_findOutputForSlider',
+ value: function _findOutputForSlider() {
+ var _this2 = this;
+
+ var result = null;
+ var outputs = document.getElementsByTagName('output') || [];
+
+ Array.from(outputs).forEach(function (output) {
+ if (output.htmlFor == _this2.element.getAttribute('id')) {
+ result = output;
+ return true;
+ }
+ });
+ return result;
+ }
+ }, {
+ key: '_getSliderOutputPosition',
+ value: function _getSliderOutputPosition() {
+ // Update output position
+ var newPlace, minValue;
+
+ var style = window.getComputedStyle(this.element, null);
+ // Measure width of range input
+ var sliderWidth = parseInt(style.getPropertyValue('width'), 10);
+
+ // Figure out placement percentage between left and right of input
+ if (!this.element.getAttribute('min')) {
+ minValue = 0;
+ } else {
+ minValue = this.element.getAttribute('min');
+ }
+ var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue);
+
+ // Prevent bubble from going beyond left or right (unsupported browsers)
+ if (newPoint < 0) {
+ newPlace = 0;
+ } else if (newPoint > 1) {
+ newPlace = sliderWidth;
+ } else {
+ newPlace = sliderWidth * newPoint;
+ }
+
+ return {
+ 'position': newPlace + 'px'
+ };
+ }
+
+ /**
+ * Bind all events
+ * @method _bindEvents
+ * @return {void}
+ */
+
+ }, {
+ key: '_bindEvents',
+ value: function _bindEvents() {
+ if (this.output) {
+ // Add event listener to update output when slider value change
+ this.element.addEventListener('input', this.onSliderInput, false);
+ }
+ }
+ }, {
+ key: 'onSliderInput',
+ value: function onSliderInput(e) {
+ e.preventDefault();
+
+ if (this.element.classList.contains('has-output-tooltip')) {
+ // Get new output position
+ var newPosition = this._getSliderOutputPosition();
+
+ // Set output position
+ this.output.style['left'] = newPosition.position;
+ }
+
+ // Check for prefix and postfix
+ var prefix = this.output.hasAttribute('data-prefix') ? this.output.getAttribute('data-prefix') : '';
+ var postfix = this.output.hasAttribute('data-postfix') ? this.output.getAttribute('data-postfix') : '';
+
+ // Update output with slider value
+ this.output.value = prefix + this.element.value + postfix;
+
+ this.emit('bulmaslider:ready', this.element.value);
+ }
+ }], [{
+ key: 'attach',
+ value: function attach() {
+ var _this3 = this;
+
+ var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider';
+ var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
+
+ var instances = new Array();
+
+ var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector];
+ elements.forEach(function (element) {
+ if (typeof element[_this3.constructor.name] === 'undefined') {
+ var instance = new bulmaSlider(element, options);
+ element[_this3.constructor.name] = instance;
+ instances.push(instance);
+ } else {
+ instances.push(element[_this3.constructor.name]);
+ }
+ });
+
+ return instances;
+ }
+ }]);
+
+ return bulmaSlider;
+}(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]);
+
+/* harmony default export */ __webpack_exports__["default"] = (bulmaSlider);
+
+/***/ }),
+/* 1 */
+/***/ (function(module, __webpack_exports__, __webpack_require__) {
+
+"use strict";
+var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();
+
+function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }
+
+var EventEmitter = function () {
+ function EventEmitter() {
+ var listeners = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : [];
+
+ _classCallCheck(this, EventEmitter);
+
+ this._listeners = new Map(listeners);
+ this._middlewares = new Map();
+ }
+
+ _createClass(EventEmitter, [{
+ key: "listenerCount",
+ value: function listenerCount(eventName) {
+ if (!this._listeners.has(eventName)) {
+ return 0;
+ }
+
+ var eventListeners = this._listeners.get(eventName);
+ return eventListeners.length;
+ }
+ }, {
+ key: "removeListeners",
+ value: function removeListeners() {
+ var _this = this;
+
+ var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null;
+ var middleware = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
+
+ if (eventName !== null) {
+ if (Array.isArray(eventName)) {
+ name.forEach(function (e) {
+ return _this.removeListeners(e, middleware);
+ });
+ } else {
+ this._listeners.delete(eventName);
+
+ if (middleware) {
+ this.removeMiddleware(eventName);
+ }
+ }
+ } else {
+ this._listeners = new Map();
+ }
+ }
+ }, {
+ key: "middleware",
+ value: function middleware(eventName, fn) {
+ var _this2 = this;
+
+ if (Array.isArray(eventName)) {
+ name.forEach(function (e) {
+ return _this2.middleware(e, fn);
+ });
+ } else {
+ if (!Array.isArray(this._middlewares.get(eventName))) {
+ this._middlewares.set(eventName, []);
+ }
+
+ this._middlewares.get(eventName).push(fn);
+ }
+ }
+ }, {
+ key: "removeMiddleware",
+ value: function removeMiddleware() {
+ var _this3 = this;
+
+ var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null;
+
+ if (eventName !== null) {
+ if (Array.isArray(eventName)) {
+ name.forEach(function (e) {
+ return _this3.removeMiddleware(e);
+ });
+ } else {
+ this._middlewares.delete(eventName);
+ }
+ } else {
+ this._middlewares = new Map();
+ }
+ }
+ }, {
+ key: "on",
+ value: function on(name, callback) {
+ var _this4 = this;
+
+ var once = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
+
+ if (Array.isArray(name)) {
+ name.forEach(function (e) {
+ return _this4.on(e, callback);
+ });
+ } else {
+ name = name.toString();
+ var split = name.split(/,|, | /);
+
+ if (split.length > 1) {
+ split.forEach(function (e) {
+ return _this4.on(e, callback);
+ });
+ } else {
+ if (!Array.isArray(this._listeners.get(name))) {
+ this._listeners.set(name, []);
+ }
+
+ this._listeners.get(name).push({ once: once, callback: callback });
+ }
+ }
+ }
+ }, {
+ key: "once",
+ value: function once(name, callback) {
+ this.on(name, callback, true);
+ }
+ }, {
+ key: "emit",
+ value: function emit(name, data) {
+ var _this5 = this;
+
+ var silent = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
+
+ name = name.toString();
+ var listeners = this._listeners.get(name);
+ var middlewares = null;
+ var doneCount = 0;
+ var execute = silent;
+
+ if (Array.isArray(listeners)) {
+ listeners.forEach(function (listener, index) {
+ // Start Middleware checks unless we're doing a silent emit
+ if (!silent) {
+ middlewares = _this5._middlewares.get(name);
+ // Check and execute Middleware
+ if (Array.isArray(middlewares)) {
+ middlewares.forEach(function (middleware) {
+ middleware(data, function () {
+ var newData = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null;
+
+ if (newData !== null) {
+ data = newData;
+ }
+ doneCount++;
+ }, name);
+ });
+
+ if (doneCount >= middlewares.length) {
+ execute = true;
+ }
+ } else {
+ execute = true;
+ }
+ }
+
+ // If Middleware checks have been passed, execute
+ if (execute) {
+ if (listener.once) {
+ listeners[index] = null;
+ }
+ listener.callback(data);
+ }
+ });
+
+ // Dirty way of removing used Events
+ while (listeners.indexOf(null) !== -1) {
+ listeners.splice(listeners.indexOf(null), 1);
+ }
+ }
+ }
+ }]);
+
+ return EventEmitter;
+}();
+
+/* harmony default export */ __webpack_exports__["a"] = (EventEmitter);
+
+/***/ })
+/******/ ])["default"];
+});
\ No newline at end of file
diff --git a/docs/static/js/bulma-slider.min.js b/docs/static/js/bulma-slider.min.js
new file mode 100644
index 0000000000000000000000000000000000000000..7e62685763cf7668cfa8857fac0b27af2c277286
--- /dev/null
+++ b/docs/static/js/bulma-slider.min.js
@@ -0,0 +1 @@
+!function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
\ No newline at end of file
diff --git a/docs/static/js/fontawesome.all.min.js b/docs/static/js/fontawesome.all.min.js
new file mode 100644
index 0000000000000000000000000000000000000000..9ee22fdb7753983bae3986b2436bdd167730cd5b
--- /dev/null
+++ b/docs/static/js/fontawesome.all.min.js
@@ -0,0 +1,5 @@
+/*!
+ * Font Awesome Free 5.15.1 by @fontawesome - https://fontawesome.com
+ * License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License)
+ */
+!function(){"use strict";var c={},l={};try{"undefined"!=typeof window&&(c=window),"undefined"!=typeof document&&(l=document)}catch(c){}var h=(c.navigator||{}).userAgent,z=void 0===h?"":h,a=c,v=l,m=(a.document,!!v.documentElement&&!!v.head&&"function"==typeof v.addEventListener&&v.createElement,~z.indexOf("MSIE")||z.indexOf("Trident/"),"___FONT_AWESOME___"),e=function(){try{return!0}catch(c){return!1}}();var s=a||{};s[m]||(s[m]={}),s[m].styles||(s[m].styles={}),s[m].hooks||(s[m].hooks={}),s[m].shims||(s[m].shims=[]);var t=s[m];function M(c,z){var l=(2>>0;h--;)l[h]=c[h];return l}function Ac(c){return c.classList?bc(c.classList):(c.getAttribute("class")||"").split(" ").filter(function(c){return c})}function gc(c,l){var h,z=l.split("-"),a=z[0],v=z.slice(1).join("-");return a!==c||""===v||(h=v,~T.indexOf(h))?null:v}function Sc(c){return"".concat(c).replace(/&/g,"&").replace(/"/g,""").replace(/'/g,"'").replace(//g,">")}function yc(h){return Object.keys(h||{}).reduce(function(c,l){return c+"".concat(l,": ").concat(h[l],";")},"")}function wc(c){return c.size!==Lc.size||c.x!==Lc.x||c.y!==Lc.y||c.rotate!==Lc.rotate||c.flipX||c.flipY}function Zc(c){var l=c.transform,h=c.containerWidth,z=c.iconWidth,a={transform:"translate(".concat(h/2," 256)")},v="translate(".concat(32*l.x,", ").concat(32*l.y,") "),m="scale(".concat(l.size/16*(l.flipX?-1:1),", ").concat(l.size/16*(l.flipY?-1:1),") "),e="rotate(".concat(l.rotate," 0 0)");return{outer:a,inner:{transform:"".concat(v," ").concat(m," ").concat(e)},path:{transform:"translate(".concat(z/2*-1," -256)")}}}var kc={x:0,y:0,width:"100%",height:"100%"};function xc(c){var l=!(1").concat(m.map(Jc).join(""),"").concat(l,">")}var $c=function(){};function cl(c){return"string"==typeof(c.getAttribute?c.getAttribute(cc):null)}var ll={replace:function(c){var l=c[0],h=c[1].map(function(c){return Jc(c)}).join("\n");if(l.parentNode&&l.outerHTML)l.outerHTML=h+(lc.keepOriginalSource&&"svg"!==l.tagName.toLowerCase()?"\x3c!-- ".concat(l.outerHTML," Font Awesome fontawesome.com --\x3e"):"");else if(l.parentNode){var z=document.createElement("span");l.parentNode.replaceChild(z,l),z.outerHTML=h}},nest:function(c){var l=c[0],h=c[1];if(~Ac(l).indexOf(lc.replacementClass))return ll.replace(c);var z=new RegExp("".concat(lc.familyPrefix,"-.*"));delete h[0].attributes.style,delete h[0].attributes.id;var a=h[0].attributes.class.split(" ").reduce(function(c,l){return l===lc.replacementClass||l.match(z)?c.toSvg.push(l):c.toNode.push(l),c},{toNode:[],toSvg:[]});h[0].attributes.class=a.toSvg.join(" ");var v=h.map(function(c){return Jc(c)}).join("\n");l.setAttribute("class",a.toNode.join(" ")),l.setAttribute(cc,""),l.innerHTML=v}};function hl(c){c()}function zl(h,c){var z="function"==typeof c?c:$c;if(0===h.length)z();else{var l=hl;lc.mutateApproach===y&&(l=o.requestAnimationFrame||hl),l(function(){var c=!0===lc.autoReplaceSvg?ll.replace:ll[lc.autoReplaceSvg]||ll.replace,l=_c.begin("mutate");h.map(c),l(),z()})}}var al=!1;function vl(){al=!1}var ml=null;function el(c){if(t&&lc.observeMutations){var a=c.treeCallback,v=c.nodeCallback,m=c.pseudoElementsCallback,l=c.observeMutationsRoot,h=void 0===l?C:l;ml=new t(function(c){al||bc(c).forEach(function(c){if("childList"===c.type&&0 0, border_ratio=0.2)
+ image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
+ mv_image.append(image)
+ # image-conditioned (a text prompt may also be given, but leaving it empty usually works too)
+ else:
+ input_image = np.array(input_image) # uint8
+ # bg removal
+ carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
+ mask = carved_image[..., -1] > 0
+ image = recenter(carved_image, mask, border_ratio=0.2)
+ image = image.astype(np.float32) / 255.0
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+ mv_image = pipe_image(prompt, image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
+
+ mv_image_grid = np.concatenate([
+ np.concatenate([mv_image[1], mv_image[2]], axis=1),
+ np.concatenate([mv_image[3], mv_image[0]], axis=1),
+ ], axis=0)
+
+ # generate gaussians
+ input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
+ input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
+ input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
+ input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+ rays_embeddings = model.prepare_default_rays(device, elevation=input_elevation)
+ input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
+
+ with torch.no_grad():
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
+ # generate gaussians
+ gaussians = model.forward_gaussians(input_image)
+
+ # save gaussians
+ model.gs.save_ply(gaussians, output_ply_path)
+
+ # render 360 video
+ images = []
+ elevation = 0
+ if opt.fancy_video:
+ azimuth = np.arange(0, 720, 4, dtype=np.int32)
+ for azi in tqdm.tqdm(azimuth):
+
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ scale = min(azi / 360, 1)
+
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+ else:
+ azimuth = np.arange(0, 360, 2, dtype=np.int32)
+ for azi in tqdm.tqdm(azimuth):
+
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+
+ images = np.concatenate(images, axis=0)
+ imageio.mimwrite(output_video_path, images, fps=30)
+
+ return mv_image_grid, output_video_path, output_ply_path
+
+# gradio UI
+
+_TITLE = '''LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation'''
+
+_DESCRIPTION = '''
+
+
+
+
+
+* Input can be only text, only image, or both image and text.
+* If you find the output unsatisfying, try using different seeds!
+'''
+
+block = gr.Blocks(title=_TITLE).queue()
+with block:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown('# ' + _TITLE)
+ gr.Markdown(_DESCRIPTION)
+
+ with gr.Row(variant='panel'):
+ with gr.Column(scale=1):
+ # input image
+ input_image = gr.Image(label="image", type='pil')
+ # input prompt
+ input_text = gr.Textbox(label="prompt")
+ # negative prompt
+ input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
+ # elevation
+ input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
+ # inference steps
+ input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
+ # random seed
+ input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
+ # gen button
+ button_gen = gr.Button("Generate")
+
+
+ with gr.Column(scale=1):
+ with gr.Tab("Video"):
+ # final video results
+ output_video = gr.Video(label="video")
+ # ply file
+ output_file = gr.File(label="ply")
+ with gr.Tab("Multi-view Image"):
+ # multi-view results
+ output_image = gr.Image(interactive=False, show_label=False)
+
+ button_gen.click(process, inputs=[input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed], outputs=[output_image, output_video, output_file])
+
+ gr.Examples(
+ examples=[
+ "data_test/anya_rgba.png",
+ "data_test/bird_rgba.png",
+ "data_test/catstatue_rgba.png",
+ ],
+ inputs=[input_image],
+ outputs=[output_image, output_video, output_file],
+ fn=lambda x: process(input_image=x, prompt=''),
+ cache_examples=False,
+ label='Image-to-3D Examples'
+ )
+
+ gr.Examples(
+ examples=[
+ "a motorbike",
+ "a hamburger",
+ "a furry red fox head",
+ ],
+ inputs=[input_text],
+ outputs=[output_image, output_video, output_file],
+ fn=lambda x: process(input_image=None, prompt=x),
+ cache_examples=False,
+ label='Text-to-3D Examples'
+ )
+
+block.launch(server_name="0.0.0.0", share=False)
\ No newline at end of file
diff --git a/libs/LGM/convert.py b/libs/LGM/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..e898e3413e6a9d95d258005dc2c5c1bfadd94268
--- /dev/null
+++ b/libs/LGM/convert.py
@@ -0,0 +1,462 @@
+
+import os
+import tyro
+import tqdm
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from core.options import AllConfigs, Options
+from core.gs import GaussianRenderer
+
+import mcubes
+import nerfacc
+import nvdiffrast.torch as dr
+
+import kiui
+from kiui.mesh import Mesh
+from kiui.mesh_utils import clean_mesh, decimate_mesh
+from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
+from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
+from kiui.cam import orbit_camera, get_perspective
+from kiui.nn import MLP, trunc_exp
+from kiui.gridencoder import GridEncoder
+
+def get_rays(pose, h, w, fovy, opengl=True):
+
+ x, y = torch.meshgrid(
+ torch.arange(w, device=pose.device),
+ torch.arange(h, device=pose.device),
+ indexing="xy",
+ )
+ x = x.flatten()
+ y = y.flatten()
+
+ cx = w * 0.5
+ cy = h * 0.5
+ focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
+
+ camera_dirs = F.pad(
+ torch.stack(
+ [
+ (x - cx + 0.5) / focal,
+ (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
+ ],
+ dim=-1,
+ ),
+ (0, 1),
+ value=(-1.0 if opengl else 1.0),
+ ) # [hw, 3]
+
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
+
+ rays_d = safe_normalize(rays_d)
+
+ return rays_o, rays_d
+
+# Triple renderer: Gaussians, NeRF, and mesh.
+# Conversion order: gaussian --> nerf --> mesh
+class Converter(nn.Module):
+ def __init__(self, opt: Options):
+ super().__init__()
+
+ self.opt = opt
+ self.device = torch.device("cuda")
+
+ # gs renderer
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
+ self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[2, 3] = 1
+
+ self.gs_renderer = GaussianRenderer(opt)
+
+ self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)
+
+ # nerf renderer
+ if not self.opt.force_cuda_rast:
+ self.glctx = dr.RasterizeGLContext()
+ else:
+ self.glctx = dr.RasterizeCudaContext()
+
+ self.step = 0
+ self.render_step_size = 5e-3
+ self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
+ self.estimator = nerfacc.OccGridEstimator(roi_aabb=self.aabb, resolution=64, levels=1)
+
+ self.encoder_density = GridEncoder(num_levels=12) # VMEncoder(output_dim=16, mode='sum')
+ self.encoder = GridEncoder(num_levels=12)
+ self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
+ self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)
+
+ # mesh renderer
+ self.proj = torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
+ self.v = self.f = None
+ self.vt = self.ft = None
+ self.deform = None
+ self.albedo = None
+
+
+ @torch.no_grad()
+ def render_gs(self, pose):
+
+ cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ out = self.gs_renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0))
+ image = out['image'].squeeze(1).squeeze(0) # [C, H, W]
+ alpha = out['alpha'].squeeze(2).squeeze(1).squeeze(0) # [H, W]
+
+ return image, alpha
+
+ def get_density(self, xs):
+ # xs: [..., 3]
+ prefix = xs.shape[:-1]
+ xs = xs.view(-1, 3)
+ feats = self.encoder_density(xs)
+ density = trunc_exp(self.mlp_density(feats))
+ density = density.view(*prefix, 1)
+ return density
+
+ def render_nerf(self, pose):
+
+ pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
+
+ # get rays
+ resolution = self.opt.output_size
+ rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)
+
+ # update occ grid
+ if self.training:
+ def occ_eval_fn(xs):
+ sigmas = self.get_density(xs)
+ return self.render_step_size * sigmas
+
+ self.estimator.update_every_n_steps(self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8)
+ self.step += 1
+
+ # render
+ def sigma_fn(t_starts, t_ends, ray_indices):
+ t_origins = rays_o[ray_indices]
+ t_dirs = rays_d[ray_indices]
+ xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
+ sigmas = self.get_density(xs)
+ return sigmas.squeeze(-1)
+
+ with torch.no_grad():
+ ray_indices, t_starts, t_ends = self.estimator.sampling(
+ rays_o,
+ rays_d,
+ sigma_fn=sigma_fn,
+ near_plane=0.01,
+ far_plane=100,
+ render_step_size=self.render_step_size,
+ stratified=self.training,
+ cone_angle=0,
+ )
+
+ t_origins = rays_o[ray_indices]
+ t_dirs = rays_d[ray_indices]
+ xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
+ sigmas = self.get_density(xs).squeeze(-1)
+ rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))
+
+ n_rays=rays_o.shape[0]
+ weights, trans, alphas = nerfacc.render_weight_from_density(t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays)
+ color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
+ alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)
+
+ color = color + 1 * (1.0 - alpha)
+
+ color = color.view(resolution, resolution, 3).clamp(0, 1).permute(2, 0, 1).contiguous()
+ alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()
+
+ return color, alpha
+
+ def fit_nerf(self, iters=512, resolution=128):
+
+ self.opt.output_size = resolution
+
+ optimizer = torch.optim.Adam([
+ {'params': self.encoder_density.parameters(), 'lr': 1e-2},
+ {'params': self.encoder.parameters(), 'lr': 1e-2},
+ {'params': self.mlp_density.parameters(), 'lr': 1e-3},
+ {'params': self.mlp.parameters(), 'lr': 1e-3},
+ ])
+
+ print(f"[INFO] fitting nerf...")
+ pbar = tqdm.trange(iters)
+ for i in pbar:
+
+ ver = np.random.randint(-45, 45)
+ hor = np.random.randint(-180, 180)
+ rad = np.random.uniform(1.5, 3.0)
+
+ pose = orbit_camera(ver, hor, rad)
+
+ image_gt, alpha_gt = self.render_gs(pose)
+ image_pred, alpha_pred = self.render_nerf(pose)
+
+ # if i % 200 == 0:
+ # kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
+
+ loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
+ loss = loss_mse #+ 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()
+
+ loss.backward()
+ self.encoder_density.grad_total_variation(1e-8)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
+
+ print(f"[INFO] finished fitting nerf!")
+
+ def render_mesh(self, pose):
+
+ h = w = self.opt.output_size
+
+ v = self.v + self.deform
+ f = self.f
+
+ pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
+
+ # get v_clip and render rgb
+ v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
+ v_clip = v_cam @ self.proj.T
+
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
+
+ alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
+ alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0) # [H, W] important to enable gradients!
+
+ if self.albedo is None:
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
+ xyzs = xyzs.view(-1, 3)
+ mask = (alpha > 0).view(-1)
+ image = torch.zeros_like(xyzs, dtype=torch.float32)
+ if mask.any():
+ masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask].detach(), bound=1)))
+ image[mask] = masked_albedo.float()
+ else:
+ texc, texc_db = dr.interpolate(self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs='all')
+ image = torch.sigmoid(dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)) # [1, H, W, 3]
+
+ image = image.view(1, h, w, 3)
+ # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
+ image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W]
+ image = alpha * image + (1 - alpha)
+
+ return image, alpha
+
+ def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4):
+
+ self.opt.output_size = resolution
+
+ # init mesh from nerf
+ grid_size = 256
+ sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
+
+ S = 128
+ density_thresh = 10
+
+ X = torch.linspace(-1, 1, grid_size).split(S)
+ Y = torch.linspace(-1, 1, grid_size).split(S)
+ Z = torch.linspace(-1, 1, grid_size).split(S)
+
+ for xi, xs in enumerate(X):
+ for yi, ys in enumerate(Y):
+ for zi, zs in enumerate(Z):
+ xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
+ val = self.get_density(pts.to(self.device))
+ sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
+
+ print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
+
+ vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
+ vertices = vertices / (grid_size - 1.0) * 2 - 1
+
+ # clean
+ vertices = vertices.astype(np.float32)
+ triangles = triangles.astype(np.int32)
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
+ if triangles.shape[0] > decimate_target:
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
+
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
+ self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
+
+ # fit mesh from gs
+ lr_factor = 1
+ optimizer = torch.optim.Adam([
+ {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
+ {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
+ {'params': self.deform, 'lr': 1e-4},
+ ])
+
+ print(f"[INFO] fitting mesh...")
+ pbar = tqdm.trange(iters)
+ for i in pbar:
+
+ ver = np.random.randint(-10, 10)
+ hor = np.random.randint(-180, 180)
+ rad = self.opt.cam_radius # np.random.uniform(1, 2)
+
+ pose = orbit_camera(ver, hor, rad)
+
+ image_gt, alpha_gt = self.render_gs(pose)
+ image_pred, alpha_pred = self.render_mesh(pose)
+
+ loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
+ # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
+ loss_normal = normal_consistency(self.v + self.deform, self.f)
+ loss_offsets = (self.deform ** 2).sum(-1).mean()
+ loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets
+
+ loss.backward()
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # remesh periodically
+ if i > 0 and i % 512 == 0:
+ vertices = (self.v + self.deform).detach().cpu().numpy()
+ triangles = self.f.detach().cpu().numpy()
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
+ if triangles.shape[0] > decimate_target:
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
+ self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
+ lr_factor *= 0.5
+ optimizer = torch.optim.Adam([
+ {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
+ {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
+ {'params': self.deform, 'lr': 1e-4},
+ ])
+
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
+
+ # last clean
+ vertices = (self.v + self.deform).detach().cpu().numpy()
+ triangles = self.f.detach().cpu().numpy()
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
+ self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))
+
+ print(f"[INFO] finished fitting mesh!")
+
+ # uv mesh refine
+ def fit_mesh_uv(self, iters=512, resolution=512, texture_resolution=1024, padding=2):
+
+ self.opt.output_size = resolution
+
+ # unwrap uv
+ print(f"[INFO] uv unwrapping...")
+ mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
+ mesh.auto_normal()
+ mesh.auto_uv()
+
+ self.vt = mesh.vt
+ self.ft = mesh.ft
+
+ # render uv maps
+ h = w = texture_resolution
+ uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
+ uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
+
+ rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
+ xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
+ mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]
+
+ # masked query
+ xyzs = xyzs.view(-1, 3)
+ mask = (mask > 0).view(-1)
+
+ albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
+
+ if mask.any():
+ print(f"[INFO] querying texture...")
+
+ xyzs = xyzs[mask] # [M, 3]
+
+ # batched inference to avoid OOM
+ batch = []
+ head = 0
+ while head < xyzs.shape[0]:
+ tail = min(head + 640000, xyzs.shape[0])
+ batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
+ head += 640000
+
+ albedo[mask] = torch.cat(batch, dim=0)
+
+ albedo = albedo.view(h, w, -1)
+ mask = mask.view(h, w)
+ albedo = uv_padding(albedo, mask, padding)
+
+ # optimize texture
+ self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)
+
+ optimizer = torch.optim.Adam([
+ {'params': self.albedo, 'lr': 1e-3},
+ ])
+
+ print(f"[INFO] fitting mesh texture...")
+ pbar = tqdm.trange(iters)
+ for i in pbar:
+
+ # shrink to front view as we care more about it...
+ ver = np.random.randint(-5, 5)
+ hor = np.random.randint(-15, 15)
+ rad = self.opt.cam_radius # np.random.uniform(1, 2)
+
+ pose = orbit_camera(ver, hor, rad)
+
+ image_gt, alpha_gt = self.render_gs(pose)
+ image_pred, alpha_pred = self.render_mesh(pose)
+
+ loss_mse = F.mse_loss(image_pred, image_gt)
+ loss = loss_mse
+
+ loss.backward()
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
+
+ print(f"[INFO] finished fitting mesh texture!")
+
+
+ @torch.no_grad()
+ def export_mesh(self, path):
+
+ mesh = Mesh(v=self.v, f=self.f, vt=self.vt, ft=self.ft, albedo=torch.sigmoid(self.albedo), device=self.device)
+ mesh.auto_normal()
+ mesh.write(path)
+
+
+opt = tyro.cli(AllConfigs)
+
+# load a saved ply and convert to mesh
+assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py'
+
+converter = Converter(opt).cuda()
+converter.fit_nerf()
+converter.fit_mesh()
+converter.fit_mesh_uv()
+converter.export_mesh(opt.test_path.replace('.ply', '.glb'))
diff --git a/libs/LGM/core/__init__.py b/libs/LGM/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/libs/LGM/core/attention.py b/libs/LGM/core/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1382f65805a8650b3d3369c803dd4df0bc9dc8
--- /dev/null
+++ b/libs/LGM/core/attention.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Attention)")
+ else:
+ warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_q: int,
+ dim_k: int,
+ dim_v: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
+ self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
+ self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # q: [B, N, Cq]
+ # k: [B, M, Ck]
+ # v: [B, M, Cv]
+ # return: [B, N, C]
+
+ B, N, _ = q.shape
+ M = k.shape[1]
+
+ q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh]
+ k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
+ v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
+
+ attn = q @ k.transpose(-2, -1) # [B, nh, N, M]
+
+ attn = attn.softmax(dim=-1) # [B, nh, N, M]
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C]
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffCrossAttention(CrossAttention):
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(q, k, v)
+
+ B, N, _ = q.shape
+ M = k.shape[1]
+
+ q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh]
+ k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
+ v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape(B, N, -1)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/libs/LGM/core/gs.py b/libs/LGM/core/gs.py
new file mode 100644
index 0000000000000000000000000000000000000000..f06bae6242fa46d7ab362ab2aba114314231bf7e
--- /dev/null
+++ b/libs/LGM/core/gs.py
@@ -0,0 +1,194 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diff_gaussian_rasterization import (
+ GaussianRasterizationSettings,
+ GaussianRasterizer,
+)
+
+from core.options import Options
+
+import kiui
+
+class GaussianRenderer:
+ def __init__(self, opt: Options):
+
+ self.opt = opt
+ self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
+
+ # intrinsics
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
+ self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[2, 3] = 1
+
+ def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1):
+ # gaussians: [B, N, 14]
+ # cam_view, cam_view_proj: [B, V, 4, 4]
+ # cam_pos: [B, V, 3]
+
+ device = gaussians.device
+ B, V = cam_view.shape[:2]
+
+ # loop of loop...
+ images = []
+ alphas = []
+ depths = []
+
+ for b in range(B):
+
+ # pos, opacity, scale, rotation, shs
+ means3D = gaussians[b, :, 0:3].contiguous().float()
+ opacity = gaussians[b, :, 3:4].contiguous().float()
+ scales = gaussians[b, :, 4:7].contiguous().float()
+ rotations = gaussians[b, :, 7:11].contiguous().float()
+ rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3]
+
+ for v in range(V):
+
+ # render novel views
+ view_matrix = cam_view[b, v].float()
+ view_proj_matrix = cam_view_proj[b, v].float()
+ campos = cam_pos[b, v].float()
+
+ raster_settings = GaussianRasterizationSettings(
+ image_height=self.opt.output_size,
+ image_width=self.opt.output_size,
+ tanfovx=self.tan_half_fov,
+ tanfovy=self.tan_half_fov,
+ bg=self.bg_color if bg_color is None else bg_color,
+ scale_modifier=scale_modifier,
+ viewmatrix=view_matrix,
+ projmatrix=view_proj_matrix,
+ sh_degree=0,
+ campos=campos,
+ prefiltered=False,
+ debug=False,
+ )
+
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
+ rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
+ means3D=means3D,
+ means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device),
+ shs=None,
+ colors_precomp=rgbs,
+ opacities=opacity,
+ scales=scales,
+ rotations=rotations,
+ cov3D_precomp=None,
+ )
+
+ rendered_image = rendered_image.clamp(0, 1)
+
+ images.append(rendered_image)
+ alphas.append(rendered_alpha)
+ depths.append(rendered_depth)
+
+ images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size)
+ alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)
+ depths = torch.stack(depths, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)
+
+ return {
+ "image": images, # [B, V, 3, H, W]
+ "alpha": alphas, # [B, V, 1, H, W]
+ "depth": depths, # [B, V, 1, H, W]
+ }
+
+ def save_ply(self, gaussians, path, compatible=True):
+ # gaussians: [B, N, 14]
+ # compatible: save pre-activated gaussians as in the original paper
+
+ assert gaussians.shape[0] == 1, 'only support batch size 1'
+
+ from plyfile import PlyData, PlyElement
+
+ means3D = gaussians[0, :, 0:3].contiguous().float()
+ opacity = gaussians[0, :, 3:4].contiguous().float()
+ scales = gaussians[0, :, 4:7].contiguous().float()
+ rotations = gaussians[0, :, 7:11].contiguous().float()
+ shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]
+
+ # prune by opacity
+ mask = opacity.squeeze(-1) >= 0.005
+ means3D = means3D[mask]
+ opacity = opacity[mask]
+ scales = scales[mask]
+ rotations = rotations[mask]
+ shs = shs[mask]
+
+ # invert activation to make it compatible with the original ply format
+ if compatible:
+ opacity = kiui.op.inverse_sigmoid(opacity)
+ scales = torch.log(scales + 1e-8)
+ shs = (shs - 0.5) / 0.28209479177387814
+
+ xyzs = means3D.detach().cpu().numpy()
+ f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+ opacities = opacity.detach().cpu().numpy()
+ scales = scales.detach().cpu().numpy()
+ rotations = rotations.detach().cpu().numpy()
+
+ l = ['x', 'y', 'z']
+ # All channels except the 3 DC
+ for i in range(f_dc.shape[1]):
+ l.append('f_dc_{}'.format(i))
+ l.append('opacity')
+ for i in range(scales.shape[1]):
+ l.append('scale_{}'.format(i))
+ for i in range(rotations.shape[1]):
+ l.append('rot_{}'.format(i))
+
+ dtype_full = [(attribute, 'f4') for attribute in l]
+
+ elements = np.empty(xyzs.shape[0], dtype=dtype_full)
+ attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, 'vertex')
+
+ PlyData([el]).write(path)
+
+ def load_ply(self, path, compatible=True):
+
+ from plyfile import PlyData, PlyElement
+
+ plydata = PlyData.read(path)
+
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+ np.asarray(plydata.elements[0]["y"]),
+ np.asarray(plydata.elements[0]["z"])), axis=1)
+ print("Number of points at loading : ", xyz.shape[0])
+
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+ shs = np.zeros((xyz.shape[0], 3))
+ shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+ shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
+ shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
+ for idx, attr_name in enumerate(scale_names):
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
+ for idx, attr_name in enumerate(rot_names):
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
+ gaussians = torch.from_numpy(gaussians).float() # cpu
+
+ if compatible:
+ gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
+ gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
+ gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5
+
+ return gaussians
\ No newline at end of file
diff --git a/libs/LGM/core/models.py b/libs/LGM/core/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b8bf322c77a1f1920830bb19b5acc1652239cb8
--- /dev/null
+++ b/libs/LGM/core/models.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+import kiui
+from kiui.lpips import LPIPS
+
+from core.unet import UNet
+from core.options import Options
+from core.gs import GaussianRenderer
+
+
+class LGM(nn.Module):
+ def __init__(
+ self,
+ opt: Options,
+ ):
+ super().__init__()
+
+ self.opt = opt
+
+ # unet
+ self.unet = UNet(
+ 9, 14,
+ down_channels=self.opt.down_channels,
+ down_attention=self.opt.down_attention,
+ mid_attention=self.opt.mid_attention,
+ up_channels=self.opt.up_channels,
+ up_attention=self.opt.up_attention,
+ )
+
+ # last conv
+        self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: could be removed if the model is retrained
+
+ # Gaussian Renderer
+ self.gs = GaussianRenderer(opt)
+
+ # activations...
+ self.pos_act = lambda x: x.clamp(-1, 1)
+ self.scale_act = lambda x: 0.1 * F.softplus(x)
+ self.opacity_act = lambda x: torch.sigmoid(x)
+ self.rot_act = lambda x: F.normalize(x, dim=-1)
+        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: sigmoid could be used instead if the model is retrained
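+        # the raw UNet outputs are mapped to valid Gaussian parameters: positions clamped to the [-1, 1]^3 cube,
+        # positive scales via softplus, opacity in (0, 1), unit-norm quaternions, and RGB colors in [0, 1]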
+
+ # LPIPS loss
+ if self.opt.lambda_lpips > 0:
+ self.lpips_loss = LPIPS(net='vgg')
+ self.lpips_loss.requires_grad_(False)
+
+
+ def state_dict(self, **kwargs):
+ # remove lpips_loss
+ state_dict = super().state_dict(**kwargs)
+ for k in list(state_dict.keys()):
+ if 'lpips_loss' in k:
+ del state_dict[k]
+ return state_dict
+
+
+ def prepare_default_rays(self, device, elevation=0):
+
+ from kiui.cam import orbit_camera
+ from core.utils import get_rays
+
+ cam_poses = np.stack([
+ orbit_camera(elevation, 0, radius=self.opt.cam_radius),
+ orbit_camera(elevation, 90, radius=self.opt.cam_radius),
+ orbit_camera(elevation, 180, radius=self.opt.cam_radius),
+ orbit_camera(elevation, 270, radius=self.opt.cam_radius),
+ ], axis=0) # [4, 4, 4]
+ cam_poses = torch.from_numpy(cam_poses)
+
+ rays_embeddings = []
+ for i in range(cam_poses.shape[0]):
+ rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
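+            # Plucker coordinates (o x d, d) give a translation-aware 6-D encoding of each pixel ray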
+ rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
+ rays_embeddings.append(rays_plucker)
+
+ ## visualize rays for plotting figure
+ # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)
+
+ rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]
+
+ return rays_embeddings
+
+
+ def forward_gaussians(self, images):
+ # images: [B, 4, 9, H, W]
+        # return: Gaussians: [B, N, 14]
+
+ B, V, C, H, W = images.shape
+ images = images.view(B*V, C, H, W)
+
+ x = self.unet(images) # [B*4, 14, h, w]
+ x = self.conv(x) # [B*4, 14, h, w]
+
+ x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size)
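+        # each of the 4 input views contributes a splat_size x splat_size grid of Gaussians, i.e. N = 4 * splat_size^2 points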
+
+ ## visualize multi-view gaussian features for plotting figure
+ # tmp_alpha = self.opacity_act(x[0, :, 3:4])
+ # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha)
+ # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5
+ # kiui.vis.plot_image(tmp_img_rgb, save=True)
+ # kiui.vis.plot_image(tmp_img_pos, save=True)
+
+ x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
+
+ pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
+ opacity = self.opacity_act(x[..., 3:4])
+ scale = self.scale_act(x[..., 4:7])
+ rotation = self.rot_act(x[..., 7:11])
+ rgbs = self.rgb_act(x[..., 11:])
+
+ gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
+
+ return gaussians
+
+
+ def forward(self, data, step_ratio=1):
+ # data: output of the dataloader
+ # return: loss
+
+ results = {}
+ loss = 0
+
+        images = data['input'] # [B, 4, 9, H, W], input features
+
+        # use the 4 input views to predict gaussians
+ gaussians = self.forward_gaussians(images) # [B, N, 14]
+
+ results['gaussians'] = gaussians
+
+ # always use white bg
+ bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)
+
+        # use the other views for rendering and supervision
+        # keep the render output in its own dict so the 'gaussians' entry stored above is not overwritten
+        render_results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
+        pred_images = render_results['image'] # [B, V, C, output_size, output_size]
+        pred_alphas = render_results['alpha'] # [B, V, 1, output_size, output_size]
+
+ results['images_pred'] = pred_images
+ results['alphas_pred'] = pred_alphas
+
+ gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views
+ gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks
+
+ gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
+
+ loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
+ loss = loss + loss_mse
+
+ if self.opt.lambda_lpips > 0:
+ loss_lpips = self.lpips_loss(
+ # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
+ # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
+ # downsampled to at most 256 to reduce memory cost
+ F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
+ F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
+ ).mean()
+ results['loss_lpips'] = loss_lpips
+ loss = loss + self.opt.lambda_lpips * loss_lpips
+
+ results['loss'] = loss
+
+ # metric
+ with torch.no_grad():
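+            # images are in [0, 1], so PSNR = -10 * log10(MSE) with a peak value of 1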
+ psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
+ results['psnr'] = psnr
+
+ return results
diff --git a/libs/LGM/core/options.py b/libs/LGM/core/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8c75bd4cc1f883cff55cfdae88a39d8c3ca7e4f
--- /dev/null
+++ b/libs/LGM/core/options.py
@@ -0,0 +1,120 @@
+import tyro
+from dataclasses import dataclass
+from typing import Tuple, Literal, Dict, Optional
+
+
+@dataclass
+class Options:
+ ### model
+ # Unet image input size
+ input_size: int = 256
+ # Unet definition
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
+ down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
+ mid_attention: bool = True
+ up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
+ up_attention: Tuple[bool, ...] = (True, True, True, False)
+ # Unet output size, dependent on the input_size and U-Net structure!
+ splat_size: int = 64
+ # gaussian render size
+ output_size: int = 256
+
+ ### dataset
+ # data mode (only support s3 now)
+ data_mode: Literal['s3'] = 's3'
+ # fovy of the dataset
+ fovy: float = 49.1
+ # camera near plane
+ znear: float = 0.5
+ # camera far plane
+ zfar: float = 2.5
+ # number of all views (input + output)
+ num_views: int = 12
+ # number of views
+ num_input_views: int = 4
+ # camera radius
+ cam_radius: float = 1.5 # to better use [-1, 1]^3 space
+ # num workers
+ num_workers: int = 8
+
+ ### training
+ # workspace
+ workspace: str = './workspace'
+ # resume
+ resume: Optional[str] = None
+ # batch size (per-GPU)
+ batch_size: int = 8
+ # gradient accumulation
+ gradient_accumulation_steps: int = 1
+ # training epochs
+ num_epochs: int = 30
+ # lpips loss weight
+ lambda_lpips: float = 1.0
+ # gradient clip
+ gradient_clip: float = 1.0
+ # mixed precision
+ mixed_precision: str = 'bf16'
+ # learning rate
+ lr: float = 4e-4
+ # augmentation prob for grid distortion
+ prob_grid_distortion: float = 0.5
+ # augmentation prob for camera jitter
+ prob_cam_jitter: float = 0.5
+
+ ### testing
+ # test image path
+ test_path: Optional[str] = None
+
+ ### misc
+ # nvdiffrast backend setting
+ force_cuda_rast: bool = False
+ # render fancy video with gaussian scaling effect
+ fancy_video: bool = False
+
+
+# all the default settings
+config_defaults: Dict[str, Options] = {}
+config_doc: Dict[str, str] = {}
+
+config_doc['lrm'] = 'the default settings for LGM'
+config_defaults['lrm'] = Options()
+
+config_doc['small'] = 'small model with lower resolution Gaussians'
+config_defaults['small'] = Options(
+ input_size=256,
+ splat_size=64,
+ output_size=256,
+ batch_size=8,
+ gradient_accumulation_steps=1,
+ mixed_precision='bf16',
+)
+
+config_doc['big'] = 'big model with higher resolution Gaussians'
+config_defaults['big'] = Options(
+ input_size=256,
+ up_channels=(1024, 1024, 512, 256, 128), # one more decoder
+ up_attention=(True, True, True, False, False),
+ splat_size=128,
+ output_size=512, # render & supervise Gaussians at a higher resolution.
+ batch_size=8,
+ num_views=8,
+ gradient_accumulation_steps=1,
+ mixed_precision='bf16',
+)
+
+config_doc['tiny'] = 'tiny model for ablation'
+config_defaults['tiny'] = Options(
+ input_size=256,
+ down_channels=(32, 64, 128, 256, 512),
+ down_attention=(False, False, False, False, True),
+ up_channels=(512, 256, 128),
+ up_attention=(True, False, False, False),
+ splat_size=64,
+ output_size=256,
+ batch_size=16,
+ num_views=8,
+ gradient_accumulation_steps=1,
+ mixed_precision='bf16',
+)
+
+AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
\ No newline at end of file
diff --git a/libs/LGM/core/process_image.py b/libs/LGM/core/process_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..415c7c81bb7f8df6a41ae3592a83725f8a6fdd57
--- /dev/null
+++ b/libs/LGM/core/process_image.py
@@ -0,0 +1,77 @@
+import os
+import rembg
+import numpy as np
+import cv2
+
+from PIL import Image
+
+def recenter(image, h_begin=100, w_begin=220, res=256):
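+    # crop (for positive offsets) or zero-pad (otherwise) the RGBA image into a res x res square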
+ h_image, w_image = image.shape[:2]
+ new_image = np.zeros((res, res, 4), dtype=np.uint8)
+ h_begin_new = -min(0, h_begin)
+ w_begin_new = -min(0, w_begin)
+ if h_begin > 0 and w_begin > 0:
+ new_image = image[h_begin:h_begin+res, w_begin:w_begin+res]
+ else:
+        new_image[h_begin_new:h_begin_new+h_image, w_begin_new:w_begin_new+w_image] = image
+ return new_image
+
+def recover(image, original_size=(720, 480), h_begin=100, w_begin=220, res=256):
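+    # inverse of recenter: place the res x res region back onto a canvas of the original (width, height)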
+ target_w, target_h = original_size
+ recovered_image = np.zeros((target_h, target_w, 4), dtype=np.uint8)
+ h_begin_new = -min(0, h_begin)
+ w_begin_new = -min(0, w_begin)
+ if h_begin > 0 and w_begin > 0:
+ recovered_image[h_begin:h_begin+res, w_begin:w_begin+res] = image
+ else:
+ recovered_image = image[h_begin_new:h_begin_new+target_h, w_begin_new:w_begin_new+target_w]
+ return recovered_image.astype(np.uint8)
+
+def resize_and_center_crop(image, target_h=480, target_w=720):
+
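+    # scale the image so it fully covers target_w x target_h, then center-crop to that size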
+ w, h = image.size
+ image_ratio = w / h
+
+ if target_w / target_h > image_ratio:
+ new_w = target_w
+ new_h = int(h * (target_w / w))
+ else:
+ new_h = target_h
+ new_w = int(w * (target_h / h))
+
+ image = image.resize((new_w, new_h), Image.LANCZOS)
+ left = max(0, (new_w - target_w) // 2)
+ top = max(0, (new_h - target_h) // 2)
+ right = left + target_w
+ bottom = top + target_h
+ image = image.crop((left, top, right, bottom))
+ return image
+
+if __name__ == "__main__":
+
+ base_dir = 'data_test'
+ task_name = 'plane'
+ raw_path = os.listdir(f'{base_dir}/raw_data')
+ bg_remover = rembg.new_session()
+
+ for image_path in raw_path:
+ if not f'{task_name}_original' in image_path:
+ continue
+ input_image = Image.open(f'{base_dir}/raw_data/{image_path}')
+ image = resize_and_center_crop(input_image)
+ image.save(f'{base_dir}/raw_data/{image_path[:-4]}_resized.png')
+ image.save(f'{base_dir}/{image_path.split("_")[0]}.png')
+ image = np.array(image)
+
+ carved_image = rembg.remove(image, session=bg_remover) # [H, W, 4]
+ Image.fromarray(carved_image).save(f'{base_dir}/raw_data/{image_path[:-4]}_carved.png')
+
+ ### Test
+ # mask = carved_image[..., -1] > 0
+ # image = recenter(carved_image)
+ # image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
+ # Image.fromarray(image).save(f'{base_dir}/raw_data/{image_path[:-4]}_recentered.png')
+ # image = cv2.resize(image, (280, 280), interpolation=cv2.INTER_AREA)
+ # image = recover(image, (720, 480))
+ # Image.fromarray(image).save(f'{base_dir}/raw_data/{image_path[:-4]}_recovered.png')
+
\ No newline at end of file
diff --git a/libs/LGM/core/provider_objaverse.py b/libs/LGM/core/provider_objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a90b773c75ccd5f21552f08cb8bff3630ef20782
--- /dev/null
+++ b/libs/LGM/core/provider_objaverse.py
@@ -0,0 +1,172 @@
+import os
+import cv2
+import random
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from torch.utils.data import Dataset
+
+import kiui
+from core.options import Options
+from core.utils import get_rays, grid_distortion, orbit_camera_jitter
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+class ObjaverseDataset(Dataset):
+
+ def _warn(self):
+        raise NotImplementedError('this dataset is just an example and cannot be used directly; modify it for your own data (search for the keyword TODO)')
+
+ def __init__(self, opt: Options, training=True):
+
+ self.opt = opt
+ self.training = training
+
+ # TODO: remove this barrier
+ self._warn()
+
+ # TODO: load the list of objects for training
+ self.items = []
+ with open('TODO: file containing the list', 'r') as f:
+ for line in f.readlines():
+ self.items.append(line.strip())
+
+ # naive split
+ if self.training:
+ self.items = self.items[:-self.opt.batch_size]
+ else:
+ self.items = self.items[-self.opt.batch_size:]
+
+ # default camera intrinsics
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
+ self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
+ self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
+ self.proj_matrix[2, 3] = 1
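+        # the matrix above is a perspective projection for the given fovy / znear / zfar, stored transposed so it can
+        # right-multiply the (also transposed) view matrices below: cam_view_proj = cam_view @ proj_matrix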
+
+
+ def __len__(self):
+ return len(self.items)
+
+ def __getitem__(self, idx):
+
+ uid = self.items[idx]
+ results = {}
+
+ # load num_views images
+ images = []
+ masks = []
+ cam_poses = []
+
+ vid_cnt = 0
+
+ # TODO: choose views, based on your rendering settings
+ if self.training:
+ # input views are in (36, 72), other views are randomly selected
+ vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist()
+ else:
+ # fixed views
+ vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist()
+
+ for vid in vids:
+
+ image_path = os.path.join(uid, 'rgb', f'{vid:03d}.png')
+ camera_path = os.path.join(uid, 'pose', f'{vid:03d}.txt')
+
+ try:
+ # TODO: load data (modify self.client here)
+ image = np.frombuffer(self.client.get(image_path), np.uint8)
+ image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1]
+ c2w = [float(t) for t in self.client.get(camera_path).decode().strip().split(' ')]
+ c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4)
+ except Exception as e:
+ # print(f'[WARN] dataset {uid} {vid}: {e}')
+ continue
+
+ # TODO: you may have a different camera system
+ # blender world + opencv cam --> opengl world & cam
+ c2w[1] *= -1
+ c2w[[1, 2]] = c2w[[2, 1]]
+ c2w[:3, 1:3] *= -1 # invert up and forward direction
+
+ # scale up radius to fully use the [-1, 1]^3 space!
+ c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default scale
+
+ image = image.permute(2, 0, 1) # [4, 512, 512]
+ mask = image[3:4] # [1, 512, 512]
+ image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg
+ image = image[[2,1,0]].contiguous() # bgr to rgb
+
+ images.append(image)
+ masks.append(mask.squeeze(0))
+ cam_poses.append(c2w)
+
+ vid_cnt += 1
+ if vid_cnt == self.opt.num_views:
+ break
+
+ if vid_cnt < self.opt.num_views:
+ print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
+ n = self.opt.num_views - vid_cnt
+ images = images + [images[-1]] * n
+ masks = masks + [masks[-1]] * n
+ cam_poses = cam_poses + [cam_poses[-1]] * n
+
+ images = torch.stack(images, dim=0) # [V, C, H, W]
+ masks = torch.stack(masks, dim=0) # [V, H, W]
+ cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]
+
+ # normalized camera feats as in paper (transform the first pose to a fixed position)
+ transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
+ cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]
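+        # after this transform the first input camera sits at (0, 0, cam_radius) with identity rotation, looking at the origin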
+
+ images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
+ cam_poses_input = cam_poses[:self.opt.num_input_views].clone()
+
+ # data augmentation
+ if self.training:
+ # apply random grid distortion to simulate 3D inconsistency
+ if random.random() < self.opt.prob_grid_distortion:
+ images_input[1:] = grid_distortion(images_input[1:])
+ # apply camera jittering (only to input!)
+ if random.random() < self.opt.prob_cam_jitter:
+ cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:])
+
+ images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+ # resize render ground-truth images, range still in [0, 1]
+ results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
+ results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]
+
+ # build rays for input views
+ rays_embeddings = []
+ for i in range(self.opt.num_input_views):
+ rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
+ rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
+ rays_embeddings.append(rays_plucker)
+
+
+ rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
+ final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]
+ results['input'] = final_input
+
+ # opengl to colmap camera for gaussian renderer
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ results['cam_view'] = cam_view
+ results['cam_view_proj'] = cam_view_proj
+ results['cam_pos'] = cam_pos
+
+ return results
\ No newline at end of file
diff --git a/libs/LGM/core/unet.py b/libs/LGM/core/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4134e809d0bad8263874b77a217a7fef06309355
--- /dev/null
+++ b/libs/LGM/core/unet.py
@@ -0,0 +1,319 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+from typing import Tuple, Literal
+from functools import partial
+
+from core.attention import MemEffAttention
+
+class MVAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ groups: int = 32,
+ eps: float = 1e-5,
+ residual: bool = True,
+ skip_scale: float = 1,
+ num_frames: int = 4, # WARN: hardcoded!
+ ):
+ super().__init__()
+
+ self.residual = residual
+ self.skip_scale = skip_scale
+ self.num_frames = num_frames
+
+ self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
+ self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)
+
+ def forward(self, x):
+ # x: [B*V, C, H, W]
+ BV, C, H, W = x.shape
+ B = BV // self.num_frames # assert BV % self.num_frames == 0
+
+ res = x
+ x = self.norm(x)
+
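+        # flatten the view and spatial dimensions so attention runs jointly across all num_frames views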
+ x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)
+ x = self.attn(x)
+ x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)
+
+ if self.residual:
+ x = (x + res) * self.skip_scale
+ return x
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resample: Literal['default', 'up', 'down'] = 'default',
+ groups: int = 32,
+ eps: float = 1e-5,
+ skip_scale: float = 1, # multiplied to output
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.skip_scale = skip_scale
+
+ self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.act = F.silu
+
+ self.resample = None
+ if resample == 'up':
+ self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ elif resample == 'down':
+ self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.shortcut = nn.Identity()
+ if self.in_channels != self.out_channels:
+ self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
+
+
+ def forward(self, x):
+ res = x
+
+ x = self.norm1(x)
+ x = self.act(x)
+
+ if self.resample:
+ res = self.resample(res)
+ x = self.resample(x)
+
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = self.act(x)
+ x = self.conv2(x)
+
+ x = (x + self.shortcut(res)) * self.skip_scale
+
+ return x
+
+class DownBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ downsample: bool = True,
+ attention: bool = True,
+ attention_heads: int = 16,
+ skip_scale: float = 1,
+ ):
+ super().__init__()
+
+ nets = []
+ attns = []
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale))
+ if attention:
+ attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale))
+ else:
+ attns.append(None)
+ self.nets = nn.ModuleList(nets)
+ self.attns = nn.ModuleList(attns)
+
+ self.downsample = None
+ if downsample:
+ self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
+
+ def forward(self, x):
+ xs = []
+
+ for attn, net in zip(self.attns, self.nets):
+ x = net(x)
+ if attn:
+ x = attn(x)
+ xs.append(x)
+
+ if self.downsample:
+ x = self.downsample(x)
+ xs.append(x)
+
+ return x, xs
+
+
+class MidBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ attention: bool = True,
+ attention_heads: int = 16,
+ skip_scale: float = 1,
+ ):
+ super().__init__()
+
+ nets = []
+ attns = []
+ # first layer
+ nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
+ # more layers
+ for i in range(num_layers):
+ nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
+ if attention:
+ attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale))
+ else:
+ attns.append(None)
+ self.nets = nn.ModuleList(nets)
+ self.attns = nn.ModuleList(attns)
+
+ def forward(self, x):
+ x = self.nets[0](x)
+ for attn, net in zip(self.attns, self.nets[1:]):
+ if attn:
+ x = attn(x)
+ x = net(x)
+ return x
+
+
+class UpBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_out_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ upsample: bool = True,
+ attention: bool = True,
+ attention_heads: int = 16,
+ skip_scale: float = 1,
+ ):
+ super().__init__()
+
+ nets = []
+ attns = []
+ for i in range(num_layers):
+ cin = in_channels if i == 0 else out_channels
+ cskip = prev_out_channels if (i == num_layers - 1) else out_channels
+
+ nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
+ if attention:
+ attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale))
+ else:
+ attns.append(None)
+ self.nets = nn.ModuleList(nets)
+ self.attns = nn.ModuleList(attns)
+
+ self.upsample = None
+ if upsample:
+ self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x, xs):
+
+ for attn, net in zip(self.attns, self.nets):
+ res_x = xs[-1]
+ xs = xs[:-1]
+ x = torch.cat([x, res_x], dim=1)
+ x = net(x)
+ if attn:
+ x = attn(x)
+
+ if self.upsample:
+ x = F.interpolate(x, scale_factor=2.0, mode='nearest')
+ x = self.upsample(x)
+
+ return x
+
+
+# it could be asymmetric!
+class UNet(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
+ down_attention: Tuple[bool, ...] = (False, False, False, True, True),
+ mid_attention: bool = True,
+ up_channels: Tuple[int, ...] = (1024, 512, 256),
+ up_attention: Tuple[bool, ...] = (True, True, False),
+ layers_per_block: int = 2,
+ skip_scale: float = np.sqrt(0.5),
+ ):
+ super().__init__()
+
+ # first
+ self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)
+
+ # down
+ down_blocks = []
+ cout = down_channels[0]
+ for i in range(len(down_channels)):
+ cin = cout
+ cout = down_channels[i]
+
+ down_blocks.append(DownBlock(
+ cin, cout,
+ num_layers=layers_per_block,
+ downsample=(i != len(down_channels) - 1), # not final layer
+ attention=down_attention[i],
+ skip_scale=skip_scale,
+ ))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ # mid
+ self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale)
+
+ # up
+ up_blocks = []
+ cout = up_channels[0]
+ for i in range(len(up_channels)):
+ cin = cout
+ cout = up_channels[i]
+            cskip = down_channels[max(-2 - i, -len(down_channels))] # matching encoder stage, clamped for the asymmetric case
+
+ up_blocks.append(UpBlock(
+ cin, cskip, cout,
+ num_layers=layers_per_block + 1, # one more layer for up
+ upsample=(i != len(up_channels) - 1), # not final layer
+ attention=up_attention[i],
+ skip_scale=skip_scale,
+ ))
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ # last
+ self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
+ self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
+
+
+ def forward(self, x):
+ # x: [B, Cin, H, W]
+
+ # first
+ x = self.conv_in(x)
+
+ # down
+ xss = [x]
+ for block in self.down_blocks:
+ x, xs = block(x)
+ xss.extend(xs)
+
+ # mid
+ x = self.mid_block(x)
+
+ # up
+ for block in self.up_blocks:
+ xs = xss[-len(block.nets):]
+ xss = xss[:-len(block.nets)]
+ x = block(x, xs)
+
+ # last
+ x = self.norm_out(x)
+ x = F.silu(x)
+ x = self.conv_out(x) # [B, Cout, H', W']
+
+ return x
diff --git a/libs/LGM/core/utils.py b/libs/LGM/core/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ab4f5116f6c40422efb65a9bd139ef07f9e41c
--- /dev/null
+++ b/libs/LGM/core/utils.py
@@ -0,0 +1,109 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import roma
+from kiui.op import safe_normalize
+
+def get_rays(pose, h, w, fovy, opengl=True):
+
+ x, y = torch.meshgrid(
+ torch.arange(w, device=pose.device),
+ torch.arange(h, device=pose.device),
+ indexing="xy",
+ )
+ x = x.flatten()
+ y = y.flatten()
+
+ cx = w * 0.5
+ cy = h * 0.5
+
+ focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
+
+ camera_dirs = F.pad(
+ torch.stack(
+ [
+ (x - cx + 0.5) / focal,
+ (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
+ ],
+ dim=-1,
+ ),
+ (0, 1),
+ value=(-1.0 if opengl else 1.0),
+ ) # [hw, 3]
+
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
+
+ rays_o = rays_o.view(h, w, 3)
+ rays_d = safe_normalize(rays_d).view(h, w, 3)
+
+ return rays_o, rays_d
+
+def orbit_camera_jitter(poses, strength=0.1):
+ # poses: [B, 4, 4], assume orbit camera in opengl format
+ # random orbital rotate
+
+ B = poses.shape[0]
+ rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1)
+ rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1)
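+    # the two rotation vectors above are about the camera's up (column 1) and right (column 0) axes, scaled by strength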
+
+ rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)
+ R = rot @ poses[:, :3, :3]
+ T = rot @ poses[:, :3, 3:]
+
+ new_poses = poses.clone()
+ new_poses[:, :3, :3] = R
+ new_poses[:, :3, 3:] = T
+
+ return new_poses
+
+def grid_distortion(images, strength=0.5):
+ # images: [B, C, H, W]
+    # strength: float in [0, 1], strength of distortion
+    # the grid resolution (num_steps) is randomized in [8, 16] below
+
+ B, C, H, W = images.shape
+
+ num_steps = np.random.randint(8, 17)
+ grid_steps = torch.linspace(-1, 1, num_steps)
+
+ # have to loop batch...
+ grids = []
+ for b in range(B):
+ # construct displacement
+ x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
+ x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
+ x_steps = (x_steps * W).long() # [num_steps]
+ x_steps[0] = 0
+ x_steps[-1] = W
+ xs = []
+ for i in range(num_steps - 1):
+ xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i]))
+ xs = torch.cat(xs, dim=0) # [W]
+
+ y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
+ y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
+ y_steps = (y_steps * H).long() # [num_steps]
+ y_steps[0] = 0
+ y_steps[-1] = H
+ ys = []
+ for i in range(num_steps - 1):
+ ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i]))
+ ys = torch.cat(ys, dim=0) # [H]
+
+ # construct grid
+ grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W]
+ grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2]
+
+ grids.append(grid)
+
+ grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2]
+
+ # grid sample
+ images = F.grid_sample(images, grids, align_corners=False)
+
+ return images
+
diff --git a/libs/LGM/gui.py b/libs/LGM/gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5188c7b06800de5dc3675bc7503e10c56b06608
--- /dev/null
+++ b/libs/LGM/gui.py
@@ -0,0 +1,294 @@
+
+import os
+import tyro
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from core.options import AllConfigs, Options
+from core.gs import GaussianRenderer
+
+import dearpygui.dearpygui as dpg
+
+import kiui
+from kiui.cam import OrbitCamera
+
+
+class GUI:
+ def __init__(self, opt: Options):
+ self.opt = opt
+ self.W = opt.output_size
+ self.H = opt.output_size
+ self.cam = OrbitCamera(self.W, self.H, r=opt.cam_radius, fovy=opt.fovy)
+
+ self.device = torch.device("cuda")
+
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
+ self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[2, 3] = 1
+
+ self.mode = "image"
+
+ self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
+ self.need_update = True # update buffer_image
+
+ # renderer
+ self.renderer = GaussianRenderer(opt)
+ self.gaussain_scale_factor = 1
+
+ self.gaussians = self.renderer.load_ply(opt.test_path).to(self.device)
+
+ dpg.create_context()
+ self.register_dpg()
+ self.test_step()
+
+ def __del__(self):
+ dpg.destroy_context()
+
+ @torch.no_grad()
+ def test_step(self):
+ # ignore if no need to update
+ if not self.need_update:
+ return
+
+ starter = torch.cuda.Event(enable_timing=True)
+ ender = torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ # should update image
+ if self.need_update:
+ # render image
+
+ cam_poses = torch.from_numpy(self.cam.pose).unsqueeze(0).to(self.device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ buffer_image = self.renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=self.gaussain_scale_factor)[self.mode]
+ buffer_image = buffer_image.squeeze(1) # [B, C, H, W]
+
+ if self.mode in ['alpha']:
+ buffer_image = buffer_image.repeat(1, 3, 1, 1)
+
+ buffer_image = F.interpolate(
+ buffer_image,
+ size=(self.H, self.W),
+ mode="bilinear",
+ align_corners=False,
+ ).squeeze(0)
+
+ self.buffer_image = (
+ buffer_image.permute(1, 2, 0)
+ .contiguous()
+ .clamp(0, 1)
+ .contiguous()
+ .detach()
+ .cpu()
+ .numpy()
+ )
+
+ self.need_update = False
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)")
+ dpg.set_value(
+ "_texture", self.buffer_image
+ ) # buffer must be contiguous, else seg fault!
+
+ def register_dpg(self):
+ ### register texture
+
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(
+ self.W,
+ self.H,
+ self.buffer_image,
+ format=dpg.mvFormat_Float_rgb,
+ tag="_texture",
+ )
+
+ ### register window
+
+ # the rendered image, as the primary window
+ with dpg.window(
+ tag="_primary_window",
+ width=self.W,
+ height=self.H,
+ pos=[0, 0],
+ no_move=True,
+ no_title_bar=True,
+ no_scrollbar=True,
+ ):
+ # add the texture
+ dpg.add_image("_texture")
+
+ # dpg.set_primary_window("_primary_window", True)
+
+ # control window
+ with dpg.window(
+ label="Control",
+ tag="_control_window",
+ width=600,
+ height=self.H,
+ pos=[self.W, 0],
+ no_move=True,
+ no_title_bar=True,
+ ):
+ # button theme
+ with dpg.theme() as theme_button:
+ with dpg.theme_component(dpg.mvButton):
+ dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
+ dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
+
+ # timer stuff
+ with dpg.group(horizontal=True):
+ dpg.add_text("Infer time: ")
+ dpg.add_text("no data", tag="_log_infer_time")
+
+ # rendering options
+ with dpg.collapsing_header(label="Rendering", default_open=True):
+ # mode combo
+ def callback_change_mode(sender, app_data):
+ self.mode = app_data
+ self.need_update = True
+
+ dpg.add_combo(
+ ("image", "alpha"),
+ label="mode",
+ default_value=self.mode,
+ callback=callback_change_mode,
+ )
+
+ # fov slider
+ def callback_set_fovy(sender, app_data):
+ self.cam.fovy = np.deg2rad(app_data)
+ self.need_update = True
+
+ dpg.add_slider_int(
+ label="FoV (vertical)",
+ min_value=1,
+ max_value=120,
+ format="%d deg",
+ default_value=np.rad2deg(self.cam.fovy),
+ callback=callback_set_fovy,
+ )
+
+ def callback_set_gaussain_scale(sender, app_data):
+ self.gaussain_scale_factor = app_data
+ self.need_update = True
+
+ dpg.add_slider_float(
+                    label="gaussian scale",
+ min_value=0,
+ max_value=1,
+ format="%.2f",
+ default_value=self.gaussain_scale_factor,
+ callback=callback_set_gaussain_scale,
+ )
+
+ ### register camera handler
+
+ def callback_camera_drag_rotate(sender, app_data):
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.orbit(dx, dy)
+ self.need_update = True
+
+ def callback_camera_wheel_scale(sender, app_data):
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ delta = app_data
+
+ self.cam.scale(delta)
+ self.need_update = True
+
+ def callback_camera_drag_pan(sender, app_data):
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.pan(dx, dy)
+ self.need_update = True
+
+ with dpg.handler_registry():
+ # for camera moving
+ dpg.add_mouse_drag_handler(
+ button=dpg.mvMouseButton_Left,
+ callback=callback_camera_drag_rotate,
+ )
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
+ dpg.add_mouse_drag_handler(
+ button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
+ )
+
+ dpg.create_viewport(
+ title="Gaussian3D",
+ width=self.W + 600,
+ height=self.H + (45 if os.name == "nt" else 0),
+ resizable=False,
+ )
+
+ ### global theme
+ with dpg.theme() as theme_no_padding:
+ with dpg.theme_component(dpg.mvAll):
+ # set all padding to 0 to avoid scroll bar
+ dpg.add_theme_style(
+ dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
+ )
+ dpg.add_theme_style(
+ dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
+ )
+ dpg.add_theme_style(
+ dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
+ )
+
+ dpg.bind_item_theme("_primary_window", theme_no_padding)
+
+ dpg.setup_dearpygui()
+
+ ### register a larger font
+ # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
+ if os.path.exists("LXGWWenKai-Regular.ttf"):
+ with dpg.font_registry():
+ with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
+ dpg.bind_font(default_font)
+
+ # dpg.show_metrics()
+
+ dpg.show_viewport()
+
+ def render(self):
+ while dpg.is_dearpygui_running():
+ # update texture every frame
+ self.test_step()
+ dpg.render_dearpygui_frame()
+
+
+opt = tyro.cli(AllConfigs)
+
+# load a saved ply and visualize
+assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py'
+
+gui = GUI(opt)
+gui.render()
\ No newline at end of file
diff --git a/libs/LGM/infer.py b/libs/LGM/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..67422d45fe33ab72337a8f2880ef3cb920b16138
--- /dev/null
+++ b/libs/LGM/infer.py
@@ -0,0 +1,199 @@
+
+import os
+import tyro
+import glob
+import imageio
+import numpy as np
+import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from safetensors.torch import load_file
+import rembg
+
+import kiui
+import cv2
+import numpy as np
+
+from kiui.cam import orbit_camera
+
+from core.process_image import recenter
+from core.options import AllConfigs, Options
+from core.models import LGM
+from mvdream.pipeline_mvdream import MVDreamPipeline
+from PIL import Image
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+# Ball(11,1) Plane(0,3)
+torch.random.manual_seed(11)
+np.random.seed(11)
+
+opt = tyro.cli(AllConfigs)
+
+# model
+model = LGM(opt)
+
+# resume pretrained checkpoint
+if opt.resume is not None:
+ if opt.resume.endswith('safetensors'):
+ ckpt = load_file(opt.resume, device='cpu')
+ else:
+ ckpt = torch.load(opt.resume, map_location='cpu')
+ model.load_state_dict(ckpt, strict=False)
+ print(f'[INFO] Loaded checkpoint from {opt.resume}')
+else:
+ print(f'[WARN] model randomly initialized, are you sure?')
+
+# device
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.half().to(device)
+model.eval()
+
+rays_embeddings = model.prepare_default_rays(device)
+
+tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
+proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
+proj_matrix[0, 0] = 1 / tan_half_fov
+proj_matrix[1, 1] = 1 / tan_half_fov
+proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+proj_matrix[2, 3] = 1
+
+# load image dream
+pipe = MVDreamPipeline.from_pretrained(
+ "ashawkey/imagedream-ipmv-diffusers", # remote weights
+ torch_dtype=torch.float16,
+ trust_remote_code=True,
+ # local_files_only=True,
+)
+pipe = pipe.to(device)
+
+# load rembg
+bg_remover = rembg.new_session()
+
+# process function
+def process(opt: Options, path, task_name):
+
+ name = os.path.splitext(os.path.basename(path))[0]
+ print(f'[INFO] Processing {path} --> {name}')
+ os.makedirs(opt.workspace, exist_ok=True)
+
+ input_image = kiui.read_image(path, mode='uint8')
+
+ # bg removal
+ carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
+ mask = carved_image[..., -1] > 0
+
+ # recenter
+ h_begin, w_begin, res = 50, 160, 380
+ # h_begin, w_begin, res = -120, 0, 720
+ image = recenter(carved_image, h_begin, w_begin, res)
+
+ # generate mv
+ image = image.astype(np.float32) / 255.0
+
+ # rgba to rgb white bg
+ if image.shape[-1] == 4:
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+
+ # image_out = Image.fromarray((image * 255).astype(np.uint8))
+ # image_out = image_out.resize((256, 256), Image.BILINEAR)
+ # image_out.save(os.path.join(opt.workspace, name + '_input.png'))
+ # image = np.array(image_out).astype(np.float32) / 255.0
+ # mv_image = pipe('A basketball', image, guidance_scale=5.0, num_inference_steps=30, elevation=0)
+ # mv_image = np.stack([image, mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
+ # mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], image], axis=0) # [4, 256, 256, 3], float32
+ # mv_image = np.stack([mv_image[0], mv_image[2], mv_image[1], mv_image[3]], axis=0) # [4, 256, 256, 3], float32
+ # for i in range(4):
+ # Image.fromarray((mv_image[i] * 255).astype(np.uint8)).save(os.path.join(opt.workspace, name + f'_{i}.png'))
+
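+    # NOTE: the multi-view diffusion call above is commented out; previously generated views
+    # {name}_0.png ... {name}_3.png are loaded back from the workspace instead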
+ images = []
+ for i in range(4):
+ image = imageio.imread(os.path.join(opt.workspace, name + f'_{i}.png'))
+ image = image.astype(np.float32) / 255.0
+ if image.shape[-1] == 4:
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+ images.append(image)
+ mv_image = np.stack(images, axis=0)
+
+ # generate gaussians
+ input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
+ input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
+ input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+ input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
+
+ with torch.no_grad():
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
+ # generate gaussians
+ gaussians = model.forward_gaussians(input_image)
+
+ # save gaussians
+ model.gs.save_ply(gaussians, os.path.join(opt.workspace, name + '.ply'))
+
+ # render front view
+ cam_poses = torch.from_numpy(orbit_camera(0, 0, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+
+ os.makedirs(f'../../data/{task_name}_rendered', exist_ok=True)
+ np.save(f'../../data/{task_name}_rendered/projection.npy', cam_view_proj[0].cpu().numpy())
+
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
+ image_save = (image[0, 0].permute(1, 2, 0).contiguous().float().cpu().numpy() * 255).astype(np.uint8)
+ Image.fromarray(image_save).save(os.path.join(opt.workspace, name + '_front_view.png'))
+
+ # render 360 video
+ images = []
+ elevation = 0
+
+ if opt.fancy_video:
+ azimuth = np.arange(0, 720, 4, dtype=np.int32)
+ for azi in tqdm.tqdm(azimuth):
+
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+ scale = min(azi / 360, 1)
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+ else:
+ azimuth = np.arange(0, 360, 2, dtype=np.int32)
+ for azi in tqdm.tqdm(azimuth):
+
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+
+ images = np.concatenate(images, axis=0)
+ imageio.mimwrite(os.path.join(opt.workspace, name + '.mp4'), images, fps=30)
+
+assert opt.test_path is not None
+if os.path.isdir(opt.test_path):
+ file_paths = glob.glob(os.path.join(opt.test_path, "*"))
+else:
+ file_paths = [opt.test_path]
+
+task_name = 'chair'
+for path in file_paths:
+    if task_name not in path:
+ continue
+ process(opt, path, task_name=task_name)
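+
+# example invocation (paths are illustrative, and the precomputed {name}_0..3.png views must already exist in the workspace):
+#   python infer.py big --resume pretrained/model_fp16.safetensors --workspace workspace_test --test_path data_test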
diff --git a/libs/LGM/main.py b/libs/LGM/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d658e78efcf480da907aa7ead7735e8910fcaa8
--- /dev/null
+++ b/libs/LGM/main.py
@@ -0,0 +1,185 @@
+import tyro
+import time
+import random
+
+import torch
+from core.options import AllConfigs
+from core.models import LGM
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from safetensors.torch import load_file
+
+import kiui
+
+def main():
+ opt = tyro.cli(AllConfigs)
+
+ # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
+ accelerator = Accelerator(
+ mixed_precision=opt.mixed_precision,
+ gradient_accumulation_steps=opt.gradient_accumulation_steps,
+ # kwargs_handlers=[ddp_kwargs],
+ )
+
+ # model
+ model = LGM(opt)
+
+ # resume
+ if opt.resume is not None:
+ if opt.resume.endswith('safetensors'):
+ ckpt = load_file(opt.resume, device='cpu')
+ else:
+ ckpt = torch.load(opt.resume, map_location='cpu')
+
+ # tolerant load (only load matching shapes)
+ # model.load_state_dict(ckpt, strict=False)
+ state_dict = model.state_dict()
+ for k, v in ckpt.items():
+ if k in state_dict:
+ if state_dict[k].shape == v.shape:
+ state_dict[k].copy_(v)
+ else:
+ accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
+ else:
+ accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')
+
+ # data
+ if opt.data_mode == 's3':
+ from core.provider_objaverse import ObjaverseDataset as Dataset
+ else:
+ raise NotImplementedError
+
+ train_dataset = Dataset(opt, training=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=opt.batch_size,
+ shuffle=True,
+ num_workers=opt.num_workers,
+ pin_memory=True,
+ drop_last=True,
+ )
+
+ test_dataset = Dataset(opt, training=False)
+ test_dataloader = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=opt.batch_size,
+ shuffle=False,
+ num_workers=0,
+ pin_memory=True,
+ drop_last=False,
+ )
+
+ # optimizer
+ optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))
+
+ # scheduler (per-iteration)
+ # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3000, eta_min=1e-6)
+ total_steps = opt.num_epochs * len(train_dataloader)
+ pct_start = 3000 / total_steps
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start)
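+    # OneCycleLR with roughly 3000 warm-up steps (pct_start is the fraction of total steps spent ramping up the learning rate)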
+
+ # accelerate
+ model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, test_dataloader, scheduler
+ )
+
+ # loop
+ for epoch in range(opt.num_epochs):
+ # train
+ model.train()
+ total_loss = 0
+ total_psnr = 0
+ for i, data in enumerate(train_dataloader):
+ with accelerator.accumulate(model):
+
+ optimizer.zero_grad()
+
+ step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs
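+                # step_ratio in [0, 1] tracks overall training progress (accepted but currently unused by model.forward)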
+
+ out = model(data, step_ratio)
+ loss = out['loss']
+ psnr = out['psnr']
+ accelerator.backward(loss)
+
+ # gradient clipping
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)
+
+ optimizer.step()
+ scheduler.step()
+
+ total_loss += loss.detach()
+ total_psnr += psnr.detach()
+
+ if accelerator.is_main_process:
+ # logging
+ if i % 100 == 0:
+ mem_free, mem_total = torch.cuda.mem_get_info()
+ print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f}")
+
+ # save log images
+ if i % 500 == 0:
+ gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
+ gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
+ kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images)
+
+ # gt_alphas = data['masks_output'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]
+ # gt_alphas = gt_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, gt_alphas.shape[1] * gt_alphas.shape[3], 1)
+ # kiui.write_image(f'{opt.workspace}/train_gt_alphas_{epoch}_{i}.jpg', gt_alphas)
+
+ pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
+ pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
+ kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images)
+
+ # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]
+ # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1)
+ # kiui.write_image(f'{opt.workspace}/train_pred_alphas_{epoch}_{i}.jpg', pred_alphas)
+
+ total_loss = accelerator.gather_for_metrics(total_loss).mean()
+ total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
+ if accelerator.is_main_process:
+ total_loss /= len(train_dataloader)
+ total_psnr /= len(train_dataloader)
+ accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}")
+
+ # checkpoint
+ # if epoch % 10 == 0 or epoch == opt.num_epochs - 1:
+ accelerator.wait_for_everyone()
+ accelerator.save_model(model, opt.workspace)
+
+ # eval
+ with torch.no_grad():
+ model.eval()
+ total_psnr = 0
+ for i, data in enumerate(test_dataloader):
+
+ out = model(data)
+
+ psnr = out['psnr']
+ total_psnr += psnr.detach()
+
+ # save some images
+ if accelerator.is_main_process:
+ gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
+ gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
+ kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images)
+
+ pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
+ pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
+ kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images)
+
+ # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]
+ # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1)
+ # kiui.write_image(f'{opt.workspace}/eval_pred_alphas_{epoch}_{i}.jpg', pred_alphas)
+
+ torch.cuda.empty_cache()
+
+ total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
+ if accelerator.is_main_process:
+ total_psnr /= len(test_dataloader)
+ accelerator.print(f"[eval] epoch: {epoch} psnr: {psnr:.4f}")
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/libs/LGM/mvdream/mv_unet.py b/libs/LGM/mvdream/mv_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9ad4def5910394eb64b36f9f76c98e8eaf80ae
--- /dev/null
+++ b/libs/LGM/mvdream/mv_unet.py
@@ -0,0 +1,1005 @@
+import math
+import numpy as np
+from inspect import isfunction
+from typing import Optional, Any, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.models.modeling_utils import ModelMixin
+
+# require xformers!
+import xformers
+import xformers.ops
+
+from kiui.cam import orbit_camera
+
+def get_camera(
+ num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
+):
+ angle_gap = azimuth_span / num_frames
+ cameras = []
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
+
+ pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4]
+
+ # opengl to blender
+ if blender_coord:
+ pose[2] *= -1
+ pose[[1, 2]] = pose[[2, 1]]
+
+ cameras.append(pose.flatten())
+
+ if extra_view:
+ cameras.append(np.zeros_like(cameras[0]))
+
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None] * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ # import pdb; pdb.set_trace()
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def default(val, d):
+ if val is not None:
+ return val
+ return d() if isfunction(d) else d
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ ip_dim=0,
+ ip_weight=1,
+ ):
+ super().__init__()
+
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.ip_dim = ip_dim
+ self.ip_weight = ip_weight
+
+ if self.ip_dim > 0:
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None):
+ q = self.to_q(x)
+ context = default(context, x)
+
+ if self.ip_dim > 0:
+ # context: [B, 77 + 16(ip), 1024]
+ token_len = context.shape[1]
+ context_ip = context[:, -self.ip_dim :, :]
+ k_ip = self.to_k_ip(context_ip)
+ v_ip = self.to_v_ip(context_ip)
+ context = context[:, : (token_len - self.ip_dim), :]
+
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+        # compute attention with xformers' memory-efficient kernel
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ if self.ip_dim > 0:
+ k_ip, v_ip = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (k_ip, v_ip),
+ )
+            # compute IP-token attention with the same memory-efficient kernel
+ out_ip = xformers.ops.memory_efficient_attention(
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
+ )
+ out = out + self.ip_weight * out_ip
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class BasicTransformerBlock3D(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ context_dim,
+ dropout=0.0,
+ gated_ff=True,
+ ip_dim=0,
+ ip_weight=1,
+ ):
+ super().__init__()
+
+ self.attn1 = MemoryEfficientCrossAttention(
+ query_dim=dim,
+ context_dim=None, # self-attention
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = MemoryEfficientCrossAttention(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ # ip only applies to cross-attention
+ ip_dim=ip_dim,
+ ip_weight=ip_weight,
+ )
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+
+ def forward(self, x, context=None, num_frames=1):
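+        # x arrives as (B*F, L, C); fold the F frames (views) into the token axis so that
+        # self-attention (attn1) mixes information across views, then unfold for per-view cross-attention.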
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
+ x = self.attn1(self.norm1(x), context=None) + x
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer3D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ context_dim, # cross attention input dim
+ depth=1,
+ dropout=0.0,
+ ip_dim=0,
+ ip_weight=1,
+ ):
+ super().__init__()
+
+ if not isinstance(context_dim, list):
+ context_dim = [context_dim]
+
+ self.in_channels = in_channels
+
+ inner_dim = n_heads * d_head
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock3D(
+ inner_dim,
+ n_heads,
+ d_head,
+ context_dim=context_dim[d],
+ dropout=dropout,
+ ip_dim=ip_dim,
+ ip_weight=ip_weight,
+ )
+ for d in range(depth)
+ ]
+ )
+
+        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))  # project back to the input channel count
+
+
+ def forward(self, x, context=None, num_frames=1):
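+        # x: (B*F, C, H, W); the frame/view axis is folded into the batch dimension by the caller.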
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i], num_frames=num_frames)
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+
+ return x + x_in
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q, k, v = map(
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
+ .transpose(1, 2)
+ .reshape(b, self.heads, t.shape[1], -1)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ ):
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
+ self.proj_in = nn.Linear(embedding_dim, dim)
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, dim * ff_mult, bias=False),
+ nn.GELU(),
+ nn.Linear(dim * ff_mult, dim, bias=False),
+ )
+ ]
+ )
+ )
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1)
+ x = self.proj_in(x)
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
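+
+# Usage sketch (shapes follow how MultiViewUNetModel instantiates this below): with embedding_dim=1280
+# and num_queries=ip_dim, the Resampler maps CLIP image hidden states of shape [B, N, 1280] to ip_dim
+# image-prompt tokens of shape [B, ip_dim, output_dim] for cross-attention.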
+
+
+class CondSequential(nn.Sequential):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None, num_frames=1):
+ for layer in self:
+ if isinstance(layer, ResBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer3D):
+ x = layer(x, context, num_frames=num_frames)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=padding
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(nn.Module):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ nn.GroupNorm(32, channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ nn.GroupNorm(32, self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class MultiViewUNetModel(ModelMixin, ConfigMixin):
+ """
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ :param camera_dim: dimensionality of camera input.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ transformer_depth=1,
+ context_dim=None,
+ n_embed=None,
+ num_attention_blocks=None,
+ adm_in_channels=None,
+ camera_dim=None,
+ ip_dim=0, # imagedream uses ip_dim > 0
+ ip_weight=1.0,
+ **kwargs,
+ ):
+ super().__init__()
+ assert context_dim is not None
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ print(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ )
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ self.ip_dim = ip_dim
+ self.ip_weight = ip_weight
+
+ if self.ip_dim > 0:
+ self.image_embed = Resampler(
+ dim=context_dim,
+ depth=4,
+ dim_head=64,
+ heads=12,
+ num_queries=ip_dim, # num token
+ embedding_dim=1280,
+ output_dim=context_dim,
+ ff_mult=4,
+ )
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ if camera_dim is not None:
+ time_embed_dim = model_channels * 4
+ self.camera_embed = nn.Sequential(
+ nn.Linear(camera_dim, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ # print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ CondSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers: List[Any] = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
+ layers.append(
+ SpatialTransformer3D(
+ ch,
+ num_heads,
+ dim_head,
+ context_dim=context_dim,
+ depth=transformer_depth,
+ ip_dim=self.ip_dim,
+ ip_weight=self.ip_weight,
+ )
+ )
+ self.input_blocks.append(CondSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ CondSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ self.middle_block = CondSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ SpatialTransformer3D(
+ ch,
+ num_heads,
+ dim_head,
+ context_dim=context_dim,
+ depth=transformer_depth,
+ ip_dim=self.ip_dim,
+ ip_weight=self.ip_weight,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
+ layers.append(
+ SpatialTransformer3D(
+ ch,
+ num_heads,
+ dim_head,
+ context_dim=context_dim,
+ depth=transformer_depth,
+ ip_dim=self.ip_dim,
+ ip_weight=self.ip_weight,
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(CondSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ nn.GroupNorm(32, ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ nn.GroupNorm(32, ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def forward(
+ self,
+ x,
+ timesteps=None,
+ context=None,
+ y=None,
+ camera=None,
+ num_frames=1,
+ ip=None,
+ ip_img=None,
+ **kwargs,
+ ):
+ """
+ Apply the model to an input batch.
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
+ :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via cross-attention.
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :param num_frames: an integer indicating the number of frames for tensor reshaping.
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
+ """
+ assert (
+ x.shape[0] % num_frames == 0
+        ), "input batch size must be divisible by num_frames!"
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+
+ hs = []
+
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y is not None
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ # Add camera embeddings
+ if camera is not None:
+ emb = emb + self.camera_embed(camera)
+
+ # imagedream variant
+ if self.ip_dim > 0:
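+            # Replace the last view in each group of num_frames with the conditioning image latent,
+            # and append its resampled CLIP tokens to the text context for cross-attention.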
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
+ ip_emb = self.image_embed(ip)
+ context = torch.cat((context, ip_emb), 1)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context, num_frames=num_frames)
+ hs.append(h)
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
+ for module in self.output_blocks:
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context, num_frames=num_frames)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
\ No newline at end of file
diff --git a/libs/LGM/mvdream/pipeline_mvdream.py b/libs/LGM/mvdream/pipeline_mvdream.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b7b3a72558aa03f77a1626e26d4f87fd830dd90
--- /dev/null
+++ b/libs/LGM/mvdream/pipeline_mvdream.py
@@ -0,0 +1,559 @@
+import torch
+import torch.nn.functional as F
+import inspect
+import numpy as np
+from typing import Callable, List, Optional, Union
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
+from diffusers import AutoencoderKL, DiffusionPipeline
+from diffusers.utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+)
+from diffusers.configuration_utils import FrozenDict
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils.torch_utils import randn_tensor
+
+from mvdream.mv_unet import MultiViewUNetModel, get_camera
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class MVDreamPipeline(DiffusionPipeline):
+
+ _optional_components = ["feature_extractor", "image_encoder"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ unet: MultiViewUNetModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModel,
+ scheduler: DDIMScheduler,
+ # imagedream variant
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModel,
+ requires_safety_checker: bool = False,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
+                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate(
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate(
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ unet=unet,
+ scheduler=scheduler,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError(
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
+ )
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError(
+                "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
+ )
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(
+ cpu_offloaded_model, device, prev_module_hook=hook
+ )
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance: bool,
+ negative_prompt=None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation.
+                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
+ )
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(
+ prompt, padding="longest", return_tensors="pt"
+ ).input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if (
+ hasattr(self.text_encoder.config, "use_attention_mask")
+ and self.text_encoder.config.use_attention_mask
+ ):
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(
+ bs_embed * num_images_per_prompt, seq_len, -1
+ )
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if (
+ hasattr(self.text_encoder.config, "use_attention_mask")
+ and self.text_encoder.config.use_attention_mask
+ ):
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=self.text_encoder.dtype, device=device
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_images_per_prompt, 1
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(
+ shape, generator=generator, device=device, dtype=dtype
+ )
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if image.dtype == np.float32:
+ image = (image * 255).astype(np.uint8)
+
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+ image = image.to(device=device, dtype=dtype)
+
+ image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
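+        # The zero tensor below is used as the unconditional (negative) image embedding for classifier-free guidance.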
+ return torch.zeros_like(image_embeds), image_embeds
+
+ def encode_image_latents(self, image, device, num_images_per_prompt):
+
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
+ image = 2 * image - 1
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
+ image = image.to(dtype=dtype)
+
+ posterior = self.vae.encode(image).latent_dist
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return torch.zeros_like(latents), latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: str = "",
+ image: Optional[np.ndarray] = None,
+ height: int = 256,
+ width: int = 256,
+ elevation: float = 0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.0,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: Optional[str] = "numpy",  # "pil", "numpy", or "latent"
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ num_frames: int = 4,
+ device=torch.device("cuda:0"),
+ ):
+ self.unet = self.unet.to(device=device)
+ self.vae = self.vae.to(device=device)
+ self.text_encoder = self.text_encoder.to(device=device)
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # imagedream variant
+ if image is not None:
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
+ self.image_encoder = self.image_encoder.to(device=device)
+ image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
+ image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
+
+ _prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ ) # type: ignore
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
+
+ # Prepare latent variables
+ actual_num_frames = num_frames if image is None else num_frames + 1
+ latents: torch.Tensor = self.prepare_latents(
+ actual_num_frames * num_images_per_prompt,
+ 4,
+ height,
+ width,
+ prompt_embeds_pos.dtype,
+ device,
+ generator,
+ None,
+ )
+
+ if image is not None:
+ camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device)
+ else:
+ camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device)
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ multiplier = 2 if do_classifier_free_guidance else 1
+ latent_model_input = torch.cat([latents] * multiplier)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ unet_inputs = {
+ 'x': latent_model_input,
+ 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
+ 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
+ 'num_frames': actual_num_frames,
+ 'camera': torch.cat([camera] * multiplier),
+ }
+
+ if image is not None:
+ unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
+ unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
+
+ # predict the noise residual
+ noise_pred = self.unet.forward(**unet_inputs)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents: torch.Tensor = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents) # type: ignore
+
+ # Post-processing
+ if output_type == "latent":
+ image = latents
+ elif output_type == "pil":
+ image = self.decode_latents(latents)
+ image = self.numpy_to_pil(image)
+ else: # numpy
+ image = self.decode_latents(latents)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ return image
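+
+
+# Usage sketch (the repo id and dtype are assumptions, mirroring how LGM loads these weights automatically):
+#   pipe = MVDreamPipeline.from_pretrained("ashawkey/imagedream-ipmv-diffusers",
+#                                           torch_dtype=torch.float16, trust_remote_code=True)
+#   views = pipe(prompt="a toy robot", image=rgb_float32_hwc, guidance_scale=5.0)  # numpy frames, one per view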
\ No newline at end of file
diff --git a/libs/LGM/readme.md b/libs/LGM/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..e29cb2bfa8c9461130365d9ac6e68a2c631c6887
--- /dev/null
+++ b/libs/LGM/readme.md
@@ -0,0 +1,108 @@
+
+## Large Multi-View Gaussian Model
+
+This is the official implementation of *LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation*.
+
+### [Project Page](https://me.kiui.moe/lgm/) | [Arxiv](https://arxiv.org/abs/2402.05054) | [Weights](https://huggingface.co/ashawkey/LGM) |
+
+https://github.com/3DTopia/LGM/assets/25863658/cf64e489-29f3-4935-adba-e393a24c26e8
+
+### News
+[2024.4.3] Thanks to [@yxymessi](https://github.com/yxymessi) and [@florinshen](https://github.com/florinshen), we have fixed a **severe bug in rotation normalization** [here](https://github.com/3DTopia/LGM/commit/9a0797cdbacf8e6216d0108cb00cbe43b9cb3d81). We have finetuned the model with correct normalization for 30 more epochs and uploaded new checkpoints.
+
+### Replicate Demo:
+* gaussians: [demo](https://replicate.com/camenduru/lgm) | [code](https://github.com/camenduru/LGM-replicate)
+* mesh: [demo](https://replicate.com/camenduru/lgm-ply-to-glb) | [code](https://github.com/camenduru/LGM-ply-to-glb-replicate)
+
+Thanks to [@camenduru](https://github.com/camenduru)!
+
+### Install
+
+```bash
+# xformers is required! please refer to https://github.com/facebookresearch/xformers for details.
+# for example, we use torch 2.1.0 + cuda 11.8
+pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
+pip install -U xformers --index-url https://download.pytorch.org/whl/cu118
+
+# a modified gaussian splatting (+ depth, alpha rendering)
+git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization
+pip install ./diff-gaussian-rasterization
+
+# for mesh extraction
+pip install git+https://github.com/NVlabs/nvdiffrast
+
+# other dependencies
+pip install -r requirements.txt
+```
+
+### Pretrained Weights
+
+Our pretrained weights can be downloaded from [huggingface](https://huggingface.co/ashawkey/LGM).
+
+For example, to download the fp16 model for inference:
+```bash
+mkdir pretrained && cd pretrained
+wget https://huggingface.co/ashawkey/LGM/resolve/main/model_fp16_fixrot.safetensors
+cd ..
+```
+
+For [MVDream](https://github.com/bytedance/MVDream) and [ImageDream](https://github.com/bytedance/ImageDream), we use a [diffusers implementation](https://github.com/ashawkey/mvdream_diffusers).
+Their weights will be downloaded automatically.
+
+### Inference
+
+Inference takes about 10GB GPU memory (loading ImageDream, MVDream, and our LGM).
+
+```bash
+### gradio app for both text/image to 3D
+python app.py big --resume pretrained/model_fp16.safetensors
+
+### test
+# --workspace: folder to save output (*.ply and *.mp4)
+# --test_path: path to a folder containing images, or a single image
+python infer.py big --resume pretrained/model_fp16.safetensors --workspace workspace_test --test_path data_test
+
+### local gui to visualize saved ply
+python gui.py big --output_size 800 --test_path workspace_test/saved.ply
+
+### mesh conversion
+python convert.py big --test_path workspace_test/saved.ply
+```
+
+For more options, please check [options](./core/options.py).
+
+### Training
+
+**NOTE**:
+Since the dataset used in our training is based on AWS, it cannot be directly used for training in a new environment.
+We provide the necessary training code framework, please check and modify the [dataset](./core/provider_objaverse.py) implementation!
+
+We also provide the **~80K subset of [Objaverse](https://objaverse.allenai.org/objaverse-1.0)** used to train LGM in [objaverse_filter](https://github.com/ashawkey/objaverse_filter).
+
+```bash
+# debug training
+accelerate launch --config_file acc_configs/gpu1.yaml main.py big --workspace workspace_debug
+
+# training (use slurm for multi-nodes training)
+accelerate launch --config_file acc_configs/gpu8.yaml main.py big --workspace workspace
+```
+
+### Acknowledgement
+
+This work is built on many amazing research works and open-source projects, thanks a lot to all the authors for sharing!
+
+- [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization)
+- [nvdiffrast](https://github.com/NVlabs/nvdiffrast)
+- [dearpygui](https://github.com/hoffstadt/DearPyGui)
+- [tyro](https://github.com/brentyi/tyro)
+
+### Citation
+
+```
+@article{tang2024lgm,
+ title={LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation},
+ author={Tang, Jiaxiang and Chen, Zhaoxi and Chen, Xiaokang and Wang, Tengfei and Zeng, Gang and Liu, Ziwei},
+ journal={arXiv preprint arXiv:2402.05054},
+ year={2024}
+}
+```
diff --git a/libs/LGM/requirements.txt b/libs/LGM/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ade4e49c6b22e2c0ecdca2e0b8ac507a94cd634d
--- /dev/null
+++ b/libs/LGM/requirements.txt
@@ -0,0 +1,28 @@
+torch
+numpy
+tyro
+diffusers
+dearpygui
+einops
+accelerate
+gradio
+imageio
+imageio-ffmpeg
+lpips
+matplotlib
+packaging
+Pillow
+pygltflib
+rembg[gpu,cli]
+rich
+safetensors
+scikit-image
+scikit-learn
+scipy
+tqdm
+transformers
+trimesh
+kiui >= 0.2.3
+xatlas
+roma
+plyfile
diff --git a/libs/LGM/scripts/convert_all.py b/libs/LGM/scripts/convert_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..163ae27a414a4925d8f8b7b2b55a8c3571f04792
--- /dev/null
+++ b/libs/LGM/scripts/convert_all.py
@@ -0,0 +1,15 @@
+import os
+import glob
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('dir', default='workspace', type=str)
+parser.add_argument('--gpu', default=0, type=int, help='ID of GPU to use')
+args = parser.parse_args()
+
+files = glob.glob(f'{args.dir}/*.ply')
+
+for file in files:
+ name = file.replace('.ply', '')
+ os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python convert.py big --test_path {file}')
+ # os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} kire {name}.glb --save_video {name}_mesh.mp4 --wogui')
\ No newline at end of file
diff --git a/libs/LGM/scripts/examples.sh b/libs/LGM/scripts/examples.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d68897eda043eb73cc87199fe560919bf7ead42f
--- /dev/null
+++ b/libs/LGM/scripts/examples.sh
@@ -0,0 +1,17 @@
+# debug training
+accelerate launch --config_file acc_configs/gpu1.yaml main.py big --workspace workspace_debug
+
+# training (should use slurm)
+accelerate launch --config_file acc_configs/gpu8.yaml main.py big --workspace workspace
+
+# test
+python infer.py big --workspace workspace_test --resume workspace/model.safetensors --test_path data_test
+
+# gradio app
+python app.py big --resume workspace/model.safetensors
+
+# local gui
+python gui.py big --output_size 800 --test_path workspace_test/anya_rgba.ply
+
+# mesh conversion
+python convert.py big --test_path workspace_test/anya_rgba.ply
\ No newline at end of file
diff --git a/libs/das/infer.py b/libs/das/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c68313624bb59863494abe472bfe3793d6dca45
--- /dev/null
+++ b/libs/das/infer.py
@@ -0,0 +1,118 @@
+import os
+import sys
+import argparse
+from PIL import Image
+
+import torch
+import numpy as np
+from PIL import Image
+import torchvision.transforms as transforms
+from moviepy.editor import VideoFileClip
+from diffusers.utils import load_image, load_video
+
+from models.pipelines import DiffusionAsShaderPipeline, CameraMotionGenerator, ObjectMotionGenerator
+
+def load_media(media_path, max_frames=49, transform=None):
+ """Load video or image frames and convert to tensor
+
+ Args:
+ media_path (str): Path to video or image file
+ max_frames (int): Maximum number of frames to load
+ transform (callable): Transform to apply to frames
+
+ Returns:
+        Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and whether the input was a video
+ """
+ if transform is None:
+ transform = transforms.Compose([
+ transforms.Resize((480, 720)),
+ transforms.ToTensor()
+ ])
+
+ # Determine if input is video or image based on extension
+ ext = os.path.splitext(media_path)[1].lower()
+ is_video = ext in ['.mp4', '.avi', '.mov']
+
+ if is_video:
+ frames = load_video(media_path)
+ fps = len(frames) / VideoFileClip(media_path).duration
+ else:
+ # Handle image as single frame
+ image = load_image(media_path)
+ frames = [image]
+ fps = 8 # Default fps for images
+
+ # Ensure we have exactly max_frames
+ if len(frames) > max_frames:
+ frames = frames[:max_frames]
+ elif len(frames) < max_frames:
+ last_frame = frames[-1]
+ while len(frames) < max_frames:
+ frames.append(last_frame.copy())
+
+ # Convert frames to tensor
+ video_tensor = torch.stack([transform(frame) for frame in frames])
+
+ return video_tensor, fps, is_video
+
+def inference(das, prompt, checkpoint_path, tracking_path, output_dir, input_path):
+ video_tensor, fps, is_video = load_media(input_path)
+    tracking_tensor, _, _ = load_media(tracking_path)
+ das.apply_tracking(
+ video_tensor=video_tensor,
+ fps=8,
+ tracking_tensor=tracking_tensor,
+ img_cond_tensor=None,
+ prompt=prompt,
+ checkpoint_path=checkpoint_path
+ )
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_path', type=str, default=None, help='Path to input video/image')
+ parser.add_argument('--prompt', type=str, required=True, help='Repaint prompt')
+ parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory')
+ parser.add_argument('--gpu', type=int, default=0, help='GPU device ID')
+ parser.add_argument('--checkpoint_path', type=str, required=True, help='Path to model checkpoint')
+ parser.add_argument('--depth_path', type=str, default=None, help='Path to depth image')
+    parser.add_argument('--tracking_path', type=str, default=None, help='Path to tracking video; if provided, camera motion and object manipulation will not be applied')
+ parser.add_argument('--repaint', type=str, default=None,
+ help='Path to repainted image, or "true" to perform repainting, if not provided use original frame')
+ parser.add_argument('--camera_motion', type=str, default=None, help='Camera motion mode')
+ parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
+ parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
+ parser.add_argument('--tracking_method', type=str, default="spatracker",
+                        help='Tracking method for image input: moge/spatracker; if \'moge\', the first frame is extracted from video input')
+ parser.add_argument('--coarse_video_path', type=str, default=None, help='Path to coarse video for object motion')
+    parser.add_argument('--start_noise_t', type=int, default=10, help='Noise timestep for the coarse video (controls the strength of object motion)')
+ args = parser.parse_args()
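+
+    # Example invocation (all paths are placeholders):
+    #   python infer.py --input_path demo.png --prompt "..." --checkpoint_path <ckpt_dir> \
+    #       --tracking_path tracking.mp4 --coarse_video_path coarse.mp4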
+
+ # Load input video/image
+ video_tensor, fps, is_video = load_media(args.input_path)
+
+ # Initialize pipeline
+ das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
+
+ # Repaint first frame if requested
+ repaint_img_tensor = None
+
+ # Generate tracking if not provided
+ tracking_tensor = None
+ pred_tracks = None
+ cam_motion = CameraMotionGenerator(args.camera_motion)
+
+ if args.tracking_path:
+ tracking_tensor, _, _ = load_media(args.tracking_path)
+
+ coarse_video = load_media(args.coarse_video_path)[0] if args.coarse_video_path else None
+
+ das.apply_tracking(
+ video_tensor=video_tensor,
+ fps=24,
+ tracking_tensor=tracking_tensor,
+ img_cond_tensor=repaint_img_tensor,
+ prompt=args.prompt,
+ checkpoint_path=args.checkpoint_path,
+ coarse_video=coarse_video,
+ start_noise_t=args.start_noise_t
+ )
diff --git a/libs/das/models/cogvideox_tracking.py b/libs/das/models/cogvideox_tracking.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae95a7116056d5d5f372faf477a6636d8295e78
--- /dev/null
+++ b/libs/das/models/cogvideox_tracking.py
@@ -0,0 +1,1032 @@
+from typing import Any, Dict, Optional, Tuple, Union, List, Callable
+
+import torch, os, math
+from torch import nn
+from PIL import Image
+from tqdm import tqdm
+
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
+
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, CogVideoXPipelineOutput
+from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
+from diffusers.pipelines.cogvideo.pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps
+from transformers import T5EncoderModel, T5Tokenizer
+from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.pipelines import DiffusionPipeline
+from diffusers.models.modeling_utils import ModelMixin
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+class CogVideoXTransformer3DModelTracking(CogVideoXTransformer3DModel, ModelMixin):
+ """
+ Add tracking maps to the CogVideoX transformer model.
+
+ Parameters:
+ num_tracking_blocks (`int`, defaults to `18`):
+ The number of tracking blocks to use. Must be less than or equal to num_layers.
+ """
+
+ def __init__(
+ self,
+ num_tracking_blocks: Optional[int] = 18,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ **kwargs
+ ):
+ super().__init__(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ time_embed_dim=time_embed_dim,
+ text_embed_dim=text_embed_dim,
+ num_layers=num_layers,
+ dropout=dropout,
+ attention_bias=attention_bias,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ patch_size=patch_size,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ activation_fn=activation_fn,
+ timestep_activation_fn=timestep_activation_fn,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ **kwargs
+ )
+
+ inner_dim = num_attention_heads * attention_head_dim
+ self.num_tracking_blocks = num_tracking_blocks
+
+ # Ensure num_tracking_blocks is not greater than num_layers
+ if num_tracking_blocks > num_layers:
+ raise ValueError("num_tracking_blocks must be less than or equal to num_layers")
+
+ # Create linear layers for combining hidden states and tracking maps
+ self.combine_linears = nn.ModuleList(
+ [nn.Linear(inner_dim, inner_dim) for _ in range(num_tracking_blocks)]
+ )
+
+ # Initialize weights of combine_linears to zero
+ for linear in self.combine_linears:
+ linear.weight.data.zero_()
+ linear.bias.data.zero_()
+
+ # Create transformer blocks for processing tracking maps
+ self.transformer_blocks_copy = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ time_embed_dim=self.config.time_embed_dim,
+ dropout=self.config.dropout,
+ activation_fn=self.config.activation_fn,
+ attention_bias=self.config.attention_bias,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ )
+ for _ in range(num_tracking_blocks)
+ ]
+ )
+
+ # For initial combination of hidden states and tracking maps
+ self.initial_combine_linear = nn.Linear(inner_dim, inner_dim)
+ self.initial_combine_linear.weight.data.zero_()
+ self.initial_combine_linear.bias.data.zero_()
+
+ # Freeze all parameters
+ for param in self.parameters():
+ param.requires_grad = False
+
+ # Unfreeze parameters that need to be trained
+ for linear in self.combine_linears:
+ for param in linear.parameters():
+ param.requires_grad = True
+
+ for block in self.transformer_blocks_copy:
+ for param in block.parameters():
+ param.requires_grad = True
+
+ for param in self.initial_combine_linear.parameters():
+ param.requires_grad = True
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ tracking_maps: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ # Process tracking maps
+ prompt_embed = encoder_hidden_states.clone()
+ tracking_maps_hidden_states = self.patch_embed(prompt_embed, tracking_maps)
+ tracking_maps_hidden_states = self.embedding_dropout(tracking_maps_hidden_states)
+ del prompt_embed
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+ tracking_maps = tracking_maps_hidden_states[:, text_seq_length:]
+
+ # Combine hidden states and tracking maps initially
+ combined = hidden_states + tracking_maps
+ tracking_maps = self.initial_combine_linear(combined)
+
+ # Process transformer blocks
+ for i in range(len(self.transformer_blocks)):
+ if self.training and self.gradient_checkpointing:
+ # Gradient checkpointing logic for hidden states
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.transformer_blocks[i]),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = self.transformer_blocks[i](
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ if i < len(self.transformer_blocks_copy):
+ if self.training and self.gradient_checkpointing:
+ # Gradient checkpointing logic for tracking maps
+ tracking_maps, _ = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.transformer_blocks_copy[i]),
+ tracking_maps,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ tracking_maps, _ = self.transformer_blocks_copy[i](
+ hidden_states=tracking_maps,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # Combine hidden states and tracking maps
+ tracking_maps = self.combine_linears[i](tracking_maps)
+ hidden_states = hidden_states + tracking_maps
+
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ # Note: we use `-1` instead of `channels`:
+        # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
+        # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ try:
+ model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+ print("Loaded DiffusionAsShader checkpoint directly.")
+
+ for param in model.parameters():
+ param.requires_grad = False
+
+ for linear in model.combine_linears:
+ for param in linear.parameters():
+ param.requires_grad = True
+
+ for block in model.transformer_blocks_copy:
+ for param in block.parameters():
+ param.requires_grad = True
+
+ for param in model.initial_combine_linear.parameters():
+ param.requires_grad = True
+
+ return model
+
+ except Exception as e:
+ print(f"Failed to load as DiffusionAsShader: {e}")
+ print("Attempting to load as CogVideoXTransformer3DModel and convert...")
+
+ base_model = CogVideoXTransformer3DModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+ config = dict(base_model.config)
+ config["num_tracking_blocks"] = kwargs.pop("num_tracking_blocks", 18)
+
+ model = cls(**config)
+ model.load_state_dict(base_model.state_dict(), strict=False)
+
+ model.initial_combine_linear.weight.data.zero_()
+ model.initial_combine_linear.bias.data.zero_()
+
+ for linear in model.combine_linears:
+ linear.weight.data.zero_()
+ linear.bias.data.zero_()
+
+ for i in range(model.num_tracking_blocks):
+ model.transformer_blocks_copy[i].load_state_dict(model.transformer_blocks[i].state_dict())
+
+
+ for param in model.parameters():
+ param.requires_grad = False
+
+ for linear in model.combine_linears:
+ for param in linear.parameters():
+ param.requires_grad = True
+
+ for block in model.transformer_blocks_copy:
+ for param in block.parameters():
+ param.requires_grad = True
+
+ for param in model.initial_combine_linear.parameters():
+ param.requires_grad = True
+
+ return model
+
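+    # Usage sketch (the path below is a placeholder): `from_pretrained` first tries to load a full
+    # DiffusionAsShader checkpoint; if that fails, it falls back to converting a base CogVideoX
+    # transformer, copying its blocks into the tracking branch and zero-initializing the combiners:
+    #   transformer = CogVideoXTransformer3DModelTracking.from_pretrained(
+    #       "path/to/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
+    #   )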
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Optional[Callable] = None,
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ max_shard_size: Union[int, str] = "5GB",
+ push_to_hub: bool = False,
+ **kwargs,
+ ):
+ super().save_pretrained(
+ save_directory,
+ is_main_process=is_main_process,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ variant=variant,
+ max_shard_size=max_shard_size,
+ push_to_hub=push_to_hub,
+ **kwargs,
+ )
+
+ if is_main_process:
+ config_dict = dict(self.config)
+ config_dict.pop("_name_or_path", None)
+ config_dict.pop("_use_default_values", None)
+ config_dict["_class_name"] = "CogVideoXTransformer3DModelTracking"
+ config_dict["num_tracking_blocks"] = self.num_tracking_blocks
+
+ os.makedirs(save_directory, exist_ok=True)
+ with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
+ import json
+ json.dump(config_dict, f, indent=2)
+
+class CogVideoXPipelineTracking(CogVideoXPipeline, DiffusionPipeline):
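+    """Text-to-video CogVideoX pipeline that additionally feeds tracking-map latents to the transformer.
+
+    Note: `tracking_maps` passed to `__call__` are consumed directly as latents; no VAE encoding is
+    performed inside this pipeline (see DiffusionAsShaderPipeline._infer for how tracking videos are
+    encoded before being passed in).
+    """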
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModelTracking,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
+
+ if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
+ raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ tracking_maps: Optional[torch.Tensor] = None,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ num_videos_per_prompt = 1
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ tracking_maps_latent = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ tracking_maps=tracking_maps_latent,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+ return CogVideoXPipelineOutput(frames=video)
+
+class CogVideoXImageToVideoPipelineTracking(CogVideoXImageToVideoPipeline, DiffusionPipeline):
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModelTracking,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
+
+ if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
+ raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
+
+        # Print the number of transformer blocks
+ print(f"Number of transformer blocks: {len(self.transformer.transformer_blocks)}")
+ print(f"Number of tracking transformer blocks: {len(self.transformer.transformer_blocks_copy)}")
+ self.transformer = torch.compile(self.transformer)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[torch.Tensor, Image.Image],
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ tracking_maps: Optional[torch.Tensor] = None,
+ tracking_image: Optional[torch.Tensor] = None,
+ coarse_video: Optional[torch.Tensor] = None,
+ start_noise_t: Optional[int] = 10
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ # Most of the implementation remains the same as the parent class
+ # We will modify the parts that need to handle tracking_maps
+
+ # 1. Check inputs and set default values
+ self.check_inputs(
+ image,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ del negative_prompt_embeds
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
+ device, dtype=prompt_embeds.dtype
+ )
+
+ tracking_image = self.video_processor.preprocess(tracking_image, height=height, width=width).to(
+ device, dtype=prompt_embeds.dtype
+ )
+ if self.transformer.config.in_channels != 16:
+ latent_channels = self.transformer.config.in_channels // 2
+ else:
+ latent_channels = self.transformer.config.in_channels
+ latents, image_latents = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ del image
+
+ _, tracking_image_latents = self.prepare_latents(
+ tracking_image,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents=None,
+ )
+ del tracking_image
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # from diffusers.utils import export_to_video
+ # from diffusers import DDIMScheduler
+ # self.scheduler = DDIMScheduler.from_config(self.scheduler.config)
+ # self.scheduler.set_timesteps(num_inference_steps)
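+        # If a coarse video is provided, encode it to latents, add the noise corresponding to
+        # timesteps[start_noise_t], and resume denoising from that step (an SDEdit-style refinement
+        # of the coarse result; steps before start_noise_t are skipped in the loop below).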
+ if coarse_video is not None:
+ coarse_latents = self.vae.encode(coarse_video * 2 - 1).latent_dist.sample()
+ coarse_latents = coarse_latents.permute(0, 2, 1, 3, 4) * self.vae.config.scaling_factor
+ # coarse_latents = self.scheduler.get_velocity(coarse_latents, torch.randn_like(coarse_latents), timesteps[start_noise_t])
+ coarse_latents = self.scheduler.add_noise(coarse_latents, torch.randn_like(coarse_latents), timesteps[start_noise_t])
+            print(f"Adding noise at timestep {timesteps[start_noise_t]}")
+ # video = self.decode_latents(coarse_latents)
+ # video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ # export_to_video(video[0], 'output_addnoise.mp4', fps=24)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ if coarse_video is not None and i < start_noise_t:
+ progress_bar.update()
+ continue
+
+ if coarse_video is not None and i == start_noise_t: # replace from current step
+ latents = coarse_latents
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+ del latent_image_input
+
+ # Handle tracking maps
+ if tracking_maps is not None:
+ latents_tracking_image = torch.cat([tracking_image_latents] * 2) if do_classifier_free_guidance else tracking_image_latents
+ tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
+ tracking_maps_input = torch.cat([tracking_maps_input, latents_tracking_image], dim=2)
+ del latents_tracking_image
+ else:
+ tracking_maps_input = None
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # Predict noise
+ self.transformer.to(dtype=latent_model_input.dtype)
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ tracking_maps=tracking_maps_input,
+ return_dict=False,
+ )[0]
+ del latent_model_input
+ if tracking_maps_input is not None:
+ del tracking_maps_input
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ del noise_pred_uncond, noise_pred_text
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ del noise_pred
+ latents = latents.to(prompt_embeds.dtype)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 9. Post-processing
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
+
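+# Usage sketch for the image-to-video variant (tensors and paths are placeholders): see
+# DiffusionAsShaderPipeline._infer in libs/das/models/pipelines.py for a full, working call that
+# prepares `tracking_maps` (VAE-encoded tracking latents), `tracking_image` (the first tracking frame)
+# and the optional `coarse_video` / `start_noise_t` refinement inputs.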
+class CogVideoXVideoToVideoPipelineTracking(CogVideoXVideoToVideoPipeline, DiffusionPipeline):
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModelTracking,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
+
+ if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
+ raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ video: List[Image.Image] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ strength: float = 0.8,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ tracking_maps: Optional[torch.Tensor] = None,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ strength=strength,
+ negative_prompt=negative_prompt,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ video=video,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents
+ if latents is None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ video = video.to(device=device, dtype=prompt_embeds.dtype)
+
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ latent_timestep,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ tracking_maps=tracking_maps_input,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
+
diff --git a/libs/das/models/pipelines.py b/libs/das/models/pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1ad15b57d4d8942e59b207ca097d16a64f72e6c
--- /dev/null
+++ b/libs/das/models/pipelines.py
@@ -0,0 +1,851 @@
+import os
+import sys
+import math
+from tqdm import tqdm
+from PIL import Image, ImageDraw
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# try:
+# sys.path.append(os.path.join(project_root, "submodules/MoGe"))
+# os.environ["TOKENIZERS_PARALLELISM"] = "false"
+# except:
+# print("Warning: MoGe not found, motion transfer will not be applied")
+
+import torch
+import numpy as np
+from PIL import Image
+import torchvision.transforms as transforms
+from diffusers import CogVideoXDPMScheduler
+from diffusers.utils import export_to_video, load_image, load_video
+
+# NOTE: SpaTrackerPredictor and Visualizer are required by generate_tracking_spatracker /
+# visualize_tracking_spatracker below; enable these imports if those helpers are used.
+# from models.spatracker.predictor import SpaTrackerPredictor
+# from models.spatracker.utils.visualizer import Visualizer
+from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
+
+# from submodules.MoGe.moge.model import MoGeModel
+from image_gen_aux import DepthPreprocessor
+from moviepy.editor import ImageSequenceClip
+from typing import Any, Dict, Optional, Tuple, Union, List, Callable
+
+class DiffusionAsShaderPipeline:
+ def __init__(self, gpu_id=0, output_dir='outputs'):
+ """Initialize MotionTransfer class
+
+ Args:
+ gpu_id (int): GPU device ID
+ output_dir (str): Output directory path
+ """
+ # video parameters
+ self.max_depth = 65.0
+ self.fps = 8
+
+ # camera parameters
+ self.camera_motion=None
+ self.fov=55
+
+ # device
+ self.device = f"cuda:{gpu_id}"
+ torch.cuda.set_device(gpu_id)
+
+ # files
+ self.output_dir = output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Initialize transform
+ self.transform = transforms.Compose([
+ transforms.Resize((480, 720)),
+ transforms.ToTensor()
+ ])
+
+ @torch.no_grad()
+ def _infer(
+ self,
+ prompt: str,
+ model_path: str,
+ tracking_tensor: torch.Tensor = None,
+ image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1]
+ output_path: str = "./output.mp4",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: int = 1,
+ dtype: torch.dtype = torch.bfloat16,
+ fps: int = 24,
+ seed: int = 0,
+ coarse_video: Optional[torch.Tensor] = None,
+ start_noise_t: Optional[int] = 10
+ ):
+ """
+ Generates a video based on the given prompt and saves it to the specified path.
+
+ Parameters:
+ - prompt (str): The description of the video to be generated.
+ - model_path (str): The path of the pre-trained model to be used.
+ - tracking_tensor (torch.Tensor): Tracking video tensor [T, C, H, W] in range [0,1]
+ - image_tensor (torch.Tensor): Input image tensor [C, H, W] in range [0,1]
+ - output_path (str): The path where the generated video will be saved.
+ - num_inference_steps (int): Number of steps for the inference process.
+ - guidance_scale (float): The scale for classifier-free guidance.
+ - num_videos_per_prompt (int): Number of videos to generate per prompt.
+        - dtype (torch.dtype): The data type for computation.
+        - fps (int): Frame rate of the exported video.
+        - seed (int): The seed for reproducibility.
+        - coarse_video (torch.Tensor, optional): Coarse video tensor [T, C, H, W] in range [0,1] used for refinement.
+        - start_noise_t (int, optional): Denoising step index at which the noised coarse video is injected.
+ """
+ pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(model_path, torch_dtype=dtype)
+
+ # Convert tensor to PIL Image
+ image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
+ image = Image.fromarray(image_np)
+ height, width = image.height, image.width
+
+ pipe.transformer.eval()
+ pipe.text_encoder.eval()
+ pipe.vae.eval()
+
+ # Process tracking tensor
+ tracking_maps = tracking_tensor.float() # [T, C, H, W]
+ tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
+ tracking_first_frame = tracking_maps[0:1] # Get first frame as [1, C, H, W]
+ height, width = tracking_first_frame.shape[2], tracking_first_frame.shape[3]
+
+ if coarse_video is not None:
+ coarse_video = coarse_video.to(device=self.device, dtype=dtype).unsqueeze(0).permute(0, 2, 1, 3, 4)
+
+ # 2. Set Scheduler.
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+ pipe.to(self.device, dtype=dtype)
+ # pipe.enable_sequential_cpu_offload()
+
+ pipe.vae.enable_slicing()
+ pipe.vae.enable_tiling()
+ pipe.transformer.eval()
+ pipe.text_encoder.eval()
+ pipe.vae.eval()
+
+ pipe.transformer.gradient_checkpointing = False
+
+ print("Encoding tracking maps")
+ tracking_maps = tracking_maps.unsqueeze(0) # [B, T, C, H, W]
+ tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
+ tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
+ tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
+ tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
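+        # After scaling and permuting, the tracking latents are in the [B, F, C, H, W] layout that the
+        # tracking transformer consumes (the same layout as the video latents).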
+
+ # 4. Generate the video frames based on the prompt.
+ video_generate = pipe(
+ prompt=prompt,
+ negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.",
+ image=image,
+ num_videos_per_prompt=num_videos_per_prompt,
+ num_inference_steps=num_inference_steps,
+ num_frames=49,
+ use_dynamic_cfg=True,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator().manual_seed(seed),
+ tracking_maps=tracking_maps,
+ tracking_image=tracking_first_frame,
+ height=height,
+ width=width,
+ coarse_video=coarse_video,
+ start_noise_t=start_noise_t
+ ).frames[0]
+
+        # 5. Export the generated frames to a video file (fps should be 8 to match the original video).
+        output_path = output_path if output_path else "result.mp4"
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ export_to_video(video_generate, output_path, fps=fps)
+
+ #========== camera parameters ==========#
+
+ def _set_camera_motion(self, camera_motion):
+ self.camera_motion = camera_motion
+
+ def _get_intr(self, fov, H=480, W=720):
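+        # Pinhole intrinsics from a horizontal FOV: focal = (W / 2) / tan(fov / 2).
+        # With the defaults fov=55 and W=720, focal ≈ 360 / tan(27.5°) ≈ 691.6 px.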
+ fov_rad = math.radians(fov)
+ focal_length = (W / 2) / math.tan(fov_rad / 2)
+
+ cx = W / 2
+ cy = H / 2
+
+ intr = torch.tensor([
+ [focal_length, 0, cx],
+ [0, focal_length, cy],
+ [0, 0, 1]
+ ], dtype=torch.float32)
+
+ return intr
+
+ def _apply_poses(self, pts, intr, poses):
+ """
+ Args:
+ pts (torch.Tensor): pointclouds coordinates [T, N, 3]
+ intr (torch.Tensor): camera intrinsics [T, 3, 3]
+ poses (numpy.ndarray): camera poses [T, 4, 4]
+ """
+ poses = torch.from_numpy(poses).float().to(self.device)
+
+ T, N, _ = pts.shape
+ ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
+ pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
+ pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
+ pts_cam[:,:, :3] /= pts[:, :, 2:3]
+
+ # to homogeneous
+ pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
+
+ if poses.shape[0] == 1:
+ poses = poses.repeat(T, 1, 1)
+ elif poses.shape[0] != T:
+ raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
+
+ pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
+
+ pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
+ pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
+
+ return pts_proj
+
+ def apply_traj_on_tracking(self, pred_tracks, camera_motion=None, fov=55, frame_num=49):
+ intr = self._get_intr(fov).unsqueeze(0).repeat(frame_num, 1, 1).to(self.device)
+ tracking_pts = self._apply_poses(pred_tracks.squeeze(), intr, camera_motion).unsqueeze(0)
+ return tracking_pts
+
+ ##============= SpatialTracker =============##
+
+ def generate_tracking_spatracker(self, video_tensor, density=70):
+ """Generate tracking video
+
+ Args:
+ video_tensor (torch.Tensor): Input video tensor
+
+ Returns:
+ str: Path to tracking video
+ """
+ print("Loading tracking models...")
+ # Load tracking model
+ tracker = SpaTrackerPredictor(
+ checkpoint=os.path.join(project_root, 'checkpoints/spaT_final.pth'),
+ interp_shape=(384, 576),
+ seq_length=12
+ ).to(self.device)
+
+ # Load depth model
+ self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
+ self.depth_preprocessor.to(self.device)
+
+ try:
+ video = video_tensor.unsqueeze(0).to(self.device)
+
+ video_depths = []
+ for i in range(video_tensor.shape[0]):
+ frame = (video_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+ depth = self.depth_preprocessor(Image.fromarray(frame))[0]
+ depth_tensor = transforms.ToTensor()(depth) # [1, H, W]
+ video_depths.append(depth_tensor)
+ video_depth = torch.stack(video_depths, dim=0).to(self.device)
+ # print("Video depth shape:", video_depth.shape)
+
+ segm_mask = np.ones((480, 720), dtype=np.uint8)
+
+ pred_tracks, pred_visibility, T_Firsts = tracker(
+ video * 255,
+ video_depth=video_depth,
+ grid_size=density,
+ backward_tracking=False,
+ depth_predictor=None,
+ grid_query_frame=0,
+ segm_mask=torch.from_numpy(segm_mask)[None, None].to(self.device),
+ wind_length=12,
+ progressive_tracking=False
+ )
+
+ return pred_tracks, pred_visibility, T_Firsts
+
+ finally:
+ # Clean up GPU memory
+ del tracker, self.depth_preprocessor
+ torch.cuda.empty_cache()
+
+ def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
+ video = video.unsqueeze(0).to(self.device)
+ vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
+ msk_query = (T_Firsts == 0)
+ pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
+ pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
+
+ tracking_video = vis.visualize(video=video, tracks=pred_tracks,
+ visibility=pred_visibility, save_video=False,
+ filename="temp")
+
+ tracking_video = tracking_video.squeeze(0) # [T, C, H, W]
+ wide_list = list(tracking_video.unbind(0))
+ wide_list = [wide.permute(1, 2, 0).cpu().numpy() for wide in wide_list]
+ clip = ImageSequenceClip(wide_list, fps=self.fps)
+
+ tracking_path = None
+ if save_tracking:
+ try:
+ tracking_path = os.path.join(self.output_dir, "tracking_video.mp4")
+ clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
+ print(f"Video saved to {tracking_path}")
+ except Exception as e:
+ print(f"Warning: Failed to save tracking video: {e}")
+ tracking_path = None
+
+ # Convert tracking_video back to tensor in range [0,1]
+ tracking_frames = np.array(list(clip.iter_frames())) / 255.0
+ tracking_video = torch.from_numpy(tracking_frames).permute(0, 3, 1, 2).float()
+
+ return tracking_path, tracking_video
+
+ ##============= MoGe =============##
+
+ def valid_mask(self, pixels, W, H):
+ """Check if pixels are within valid image bounds
+
+ Args:
+ pixels (numpy.ndarray): Pixel coordinates of shape [N, 2]
+ W (int): Image width
+ H (int): Image height
+
+ Returns:
+ numpy.ndarray: Boolean mask of valid pixels
+ """
+ return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] > 0) & \
+ (pixels[:, 1] < H))
+
+ def sort_points_by_depth(self, points, depths):
+ """Sort points by depth values
+
+ Args:
+ points (numpy.ndarray): Points array of shape [N, 2]
+ depths (numpy.ndarray): Depth values of shape [N]
+
+ Returns:
+ tuple: (sorted_points, sorted_depths, sort_index)
+ """
+ # Combine points and depths into a single array for sorting
+ combined = np.hstack((points, depths[:, None])) # Nx3 (points + depth)
+ # Sort by depth (last column) in descending order
+ sort_index = combined[:, -1].argsort()[::-1]
+ sorted_combined = combined[sort_index]
+ # Split back into points and depths
+ sorted_points = sorted_combined[:, :-1]
+ sorted_depths = sorted_combined[:, -1]
+ return sorted_points, sorted_depths, sort_index
+
+ def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)):
+ """Draw a rectangle on the image
+
+ Args:
+ rgb (PIL.Image): Image to draw on
+ coord (tuple): Center coordinates (x, y)
+ side_length (int): Length of rectangle sides
+ color (tuple): RGB color tuple
+ """
+ draw = ImageDraw.Draw(rgb)
+ # Calculate the bounding box of the rectangle
+ left_up_point = (coord[0] - side_length//2, coord[1] - side_length//2)
+ right_down_point = (coord[0] + side_length//2, coord[1] + side_length//2)
+ color = tuple(list(color))
+
+ draw.rectangle(
+ [left_up_point, right_down_point],
+ fill=tuple(color),
+ outline=tuple(color),
+ )
+
+ def visualize_tracking_moge(self, points, mask, save_tracking=True):
+ """Visualize tracking results from MoGe model
+
+ Args:
+ points (numpy.ndarray): Points array of shape [T, H, W, 3]
+ mask (numpy.ndarray): Binary mask of shape [H, W]
+ save_tracking (bool): Whether to save tracking video
+
+ Returns:
+ tuple: (tracking_path, tracking_video)
+ - tracking_path (str): Path to saved tracking video, None if save_tracking is False
+ - tracking_video (torch.Tensor): Tracking visualization tensor of shape [T, C, H, W] in range [0,1]
+ """
+ # Create color array
+ T, H, W, _ = points.shape
+ colors = np.zeros((H, W, 3), dtype=np.uint8)
+
+ # Set R channel - based on x coordinates (smaller on the left)
+ colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))
+
+ # Set G channel - based on y coordinates (smaller on the top)
+ colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T
+
+ # Set B channel - based on depth
+ z_values = points[0, :, :, 2] # get z values
+ inv_z = 1 / z_values # calculate 1/z
+ # Calculate 2% and 98% percentiles
+ p2 = np.percentile(inv_z, 2)
+ p98 = np.percentile(inv_z, 98)
+ # Normalize to [0,1] range
+ normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
+ colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
+ colors = colors.astype(np.uint8)
+
+ # colors = colors[mask]
+ # points = points * mask[None, :, :, None]
+
+ points = points.reshape(T, -1, 3)
+ colors = colors.reshape(-1, 3)
+
+ # Initialize list to store frames
+ frames = []
+
+ for i, pts_i in enumerate(tqdm(points)):
+ pixels, depths = pts_i[..., :2], pts_i[..., 2]
+ pixels[..., 0] = pixels[..., 0] * W
+ pixels[..., 1] = pixels[..., 1] * H
+ pixels = pixels.astype(int)
+
+ valid = self.valid_mask(pixels, W, H)
+ frame_rgb = colors[valid]
+ pixels = pixels[valid]
+ depths = depths[valid]
+
+ img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB")
+ sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
+ step = 1
+ sorted_pixels = sorted_pixels[::step]
+ sorted_rgb = frame_rgb[sort_index][::step]
+
+ for j in range(sorted_pixels.shape[0]):
+ self.draw_rectangle(
+ img,
+ coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
+ side_length=2,
+ color=sorted_rgb[j],
+ )
+ frames.append(np.array(img))
+
+ # Convert frames to video tensor in range [0,1]
+ tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
+
+ tracking_path = None
+ if save_tracking:
+ try:
+ tracking_path = os.path.join(self.output_dir, "tracking_video_moge.mp4")
+ # Convert back to uint8 for saving
+ uint8_frames = [frame.astype(np.uint8) for frame in frames]
+ clip = ImageSequenceClip(uint8_frames, fps=self.fps)
+ clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
+ print(f"Video saved to {tracking_path}")
+ except Exception as e:
+ print(f"Warning: Failed to save tracking video: {e}")
+ tracking_path = None
+
+ return tracking_path, tracking_video
+
+ def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None, coarse_video=None, start_noise_t=0, seed=0):
+ """Generate final video with motion transfer
+
+ Args:
+ video_tensor (torch.Tensor): Input video tensor [T,C,H,W]
+ fps (float): Input video FPS
+ tracking_tensor (torch.Tensor): Tracking video tensor [T,C,H,W]
+            img_cond_tensor (torch.Tensor): First-frame condition tensor [C,H,W] to use for generation
+            prompt (str): Generation prompt
+            checkpoint_path (str): Path to model checkpoint
+            coarse_video (torch.Tensor, optional): Coarse video tensor used for refinement
+            start_noise_t (int): Denoising step index at which the noised coarse video is injected
+            seed (int): Random seed for generation
+ """
+ self.fps = fps
+
+ # Use first frame if no image provided
+ if img_cond_tensor is None:
+ img_cond_tensor = video_tensor[0]
+
+ # Generate final video
+ final_output = os.path.join(os.path.abspath(self.output_dir), "result.mp4" if coarse_video is None else f"result_{start_noise_t}.mp4")
+ self._infer(
+ prompt=prompt,
+ model_path=checkpoint_path,
+ tracking_tensor=tracking_tensor,
+ image_tensor=img_cond_tensor,
+ output_path=final_output,
+ num_inference_steps=50,
+ guidance_scale=6.0,
+ dtype=torch.bfloat16,
+ fps=self.fps,
+ coarse_video=coarse_video,
+ start_noise_t=start_noise_t,
+ seed=seed
+ )
+ print(f"Final video generated successfully at: {final_output}")
+
+ def _set_object_motion(self, motion_type):
+ """Set object motion type
+
+ Args:
+ motion_type (str): Motion direction ('up', 'down', 'left', 'right')
+ """
+ self.object_motion = motion_type
+
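+# Minimal usage sketch (tensors and paths are placeholders): given a source video, a tracking video
+# and a first-frame condition image, the final video is generated with:
+#   das = DiffusionAsShaderPipeline(gpu_id=0, output_dir="outputs")
+#   das.apply_tracking(video_tensor, fps=8, tracking_tensor=tracking_video,
+#                      img_cond_tensor=first_frame, prompt="...", checkpoint_path="path/to/DaS")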
+class CameraMotionGenerator:
+ def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'):
+ self.motion_type = motion_type
+ self.frame_num = frame_num
+ self.fov = fov
+ self.device = device
+ self.W = W
+ self.H = H
+ self.intr = torch.tensor([
+ [0, 0, W / 2],
+ [0, 0, H / 2],
+ [0, 0, 1]
+ ], dtype=torch.float32, device=device)
+ # if fx, fy not provided
+ if not fx or not fy:
+ fov_rad = math.radians(fov)
+ fx = fy = (W / 2) / math.tan(fov_rad / 2)
+
+ self.intr[0, 0] = fx
+ self.intr[1, 1] = fy
+
+ def _apply_poses(self, pts, poses):
+ """
+ Args:
+            pts (torch.Tensor): point cloud coordinates [T, N, 3]
+            poses (torch.Tensor or numpy.ndarray): camera poses [T, 4, 4]; intrinsics are taken from self.intr
+ """
+ if isinstance(poses, np.ndarray):
+ poses = torch.from_numpy(poses)
+
+ intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
+ T, N, _ = pts.shape
+ ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
+ pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
+ pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
+ pts_cam[:,:, :3] *= pts[:, :, 2:3]
+
+ # to homogeneous
+ pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
+
+ if poses.shape[0] == 1:
+ poses = poses.repeat(T, 1, 1)
+ elif poses.shape[0] != T:
+ raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
+
+ poses = poses.to(torch.float).to(self.device)
+ pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
+ pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
+ pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
+
+ return pts_proj
+
+ def w2s(self, pts, poses):
+ if isinstance(poses, np.ndarray):
+ poses = torch.from_numpy(poses)
+ assert poses.shape[0] == self.frame_num
+ poses = poses.to(torch.float32).to(self.device)
+ T, N, _ = pts.shape # (T, N, 3)
+ intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
+        # Step 1: expand the points to homogeneous coordinates (T, N, 4) by appending a column of ones
+ ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
+ points_world_h = torch.cat([pts, ones], dim=-1)
+ points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
+ points_camera = points_camera_h[:, :3, :].permute(0, 2, 1)
+
+ points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
+
+ uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
+
+        # Step 5: extract depth (Z) and concatenate it with the screen coordinates
+ depth = points_camera[:, :, 2:3] # (T, N, 1)
+ uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)
+
+        return uvd  # screen coordinates + depth (T, N, 3)
+
+ def apply_motion_on_pts(self, pts, camera_motion):
+ tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
+ return tracking_pts
+
+ def set_intr(self, K):
+ if isinstance(K, np.ndarray):
+ K = torch.from_numpy(K)
+ self.intr = K.to(self.device)
+
+ def rot_poses(self, angle, axis='y'):
+ """
+        angle (float): total rotation angle in degrees, interpolated linearly from 0 across the frames
+        axis (str): rotation axis, one of 'x', 'y', or 'z'
+ """
+ angle_rad = math.radians(angle)
+ angles = torch.linspace(0, angle_rad, self.frame_num)
+ rot_mats = torch.zeros(self.frame_num, 4, 4)
+
+ for i, theta in enumerate(angles):
+ cos_theta = torch.cos(theta)
+ sin_theta = torch.sin(theta)
+ if axis == 'x':
+ rot_mats[i] = torch.tensor([
+ [1, 0, 0, 0],
+ [0, cos_theta, -sin_theta, 0],
+ [0, sin_theta, cos_theta, 0],
+ [0, 0, 0, 1]
+ ], dtype=torch.float32)
+ elif axis == 'y':
+ rot_mats[i] = torch.tensor([
+ [cos_theta, 0, sin_theta, 0],
+ [0, 1, 0, 0],
+ [-sin_theta, 0, cos_theta, 0],
+ [0, 0, 0, 1]
+ ], dtype=torch.float32)
+
+ elif axis == 'z':
+ rot_mats[i] = torch.tensor([
+ [cos_theta, -sin_theta, 0, 0],
+ [sin_theta, cos_theta, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]
+ ], dtype=torch.float32)
+ else:
+ raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.")
+
+ return rot_mats.to(self.device)
+
+ def trans_poses(self, dx, dy, dz):
+ """
+        Args:
+            dx (float): total displacement along the x axis
+            dy (float): total displacement along the y axis
+            dz (float): total displacement along the z axis
+
+        Returns:
+            torch.Tensor: per-frame translation matrices of shape [frame_num, 4, 4]
+ """
+ trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1) # (n, 4, 4)
+
+ delta_x = dx / (self.frame_num - 1)
+ delta_y = dy / (self.frame_num - 1)
+ delta_z = dz / (self.frame_num - 1)
+
+ for i in range(self.frame_num):
+ trans_mats[i, 0, 3] = i * delta_x
+ trans_mats[i, 1, 3] = i * delta_y
+ trans_mats[i, 2, 3] = i * delta_z
+
+ return trans_mats.to(self.device)
+
+
+ def _look_at(self, camera_position, target_position):
+ # look at direction
+ direction = target_position - camera_position
+ direction /= np.linalg.norm(direction)
+ # calculate rotation matrix
+ up = np.array([0, 1, 0])
+ right = np.cross(up, direction)
+ right /= np.linalg.norm(right)
+ up = np.cross(direction, right)
+ rotation_matrix = np.vstack([right, up, direction])
+ rotation_matrix = np.linalg.inv(rotation_matrix)
+ return rotation_matrix
+
+ def spiral_poses(self, radius, forward_ratio = 0.5, backward_ratio = 0.5, rotation_times = 0.1, look_at_times = 0.5):
+ """Generate spiral camera poses
+
+ Args:
+ radius (float): Base radius of the spiral
+ forward_ratio (float): Scale factor for forward motion
+ backward_ratio (float): Scale factor for backward motion
+ rotation_times (float): Number of rotations to complete
+ look_at_times (float): Scale factor for look-at point distance
+
+ Returns:
+ torch.Tensor: Camera poses of shape [num_frames, 4, 4]
+ """
+ # Generate spiral trajectory
+ t = np.linspace(0, 1, self.frame_num)
+ r = np.sin(np.pi * t) * radius * rotation_times
+ theta = 2 * np.pi * t
+
+ # Calculate camera positions
+ # Limit y motion for better floor/sky view
+ y = r * np.cos(theta) * 0.3
+ x = r * np.sin(theta)
+ z = -r
+ z[z < 0] *= forward_ratio
+ z[z > 0] *= backward_ratio
+
+ # Set look-at target
+ target_pos = np.array([0, 0, radius * look_at_times])
+ cam_pos = np.vstack([x, y, z]).T
+ cam_poses = []
+
+ for pos in cam_pos:
+ rot_mat = self._look_at(pos, target_pos)
+ trans_mat = np.eye(4)
+ trans_mat[:3, :3] = rot_mat
+ trans_mat[:3, 3] = pos
+ cam_poses.append(trans_mat[None])
+
+ camera_poses = np.concatenate(cam_poses, axis=0)
+ return torch.from_numpy(camera_poses).to(self.device)
+
+ def rot(self, pts, angle, axis):
+ """
+        pts: torch.Tensor, (T, N, 3)
+ """
+ rot_mats = self.rot_poses(angle, axis)
+ pts = self.apply_motion_on_pts(pts, rot_mats)
+ return pts
+
+ def trans(self, pts, dx, dy, dz):
+ if pts.shape[-1] != 3:
+ raise ValueError("points should be in the 3d coordinate.")
+ trans_mats = self.trans_poses(dx, dy, dz)
+ pts = self.apply_motion_on_pts(pts, trans_mats)
+ return pts
+
+ def spiral(self, pts, radius):
+ spiral_poses = self.spiral_poses(radius)
+ pts = self.apply_motion_on_pts(pts, spiral_poses)
+ return pts
+
+ def get_default_motion(self):
+ if self.motion_type == 'trans':
+ motion = self.trans_poses(0.1, 0, 0)
+ elif self.motion_type == 'spiral':
+ motion = self.spiral_poses(1)
+ elif self.motion_type == 'rot':
+ motion = self.rot_poses(-25, 'y')
+ else:
+ raise ValueError(f'camera_motion must be in [trans, spiral, rot], but get {self.motion_type}.')
+
+ return motion
+
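+# Usage sketch (values are illustrative): project 3D tracks under the default camera motion, e.g.
+#   cam = CameraMotionGenerator('rot', frame_num=49)
+#   poses = cam.get_default_motion()                 # [49, 4, 4] poses (a -25 degree pan about y)
+#   pts_proj = cam.apply_motion_on_pts(pred_tracks, poses)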
+class ObjectMotionGenerator:
+ def __init__(self, device="cuda:0"):
+ """Initialize ObjectMotionGenerator
+
+ Args:
+ device (str): Device to run on
+ """
+ self.device = device
+ self.num_frames = 49
+
+ def _get_points_in_mask(self, pred_tracks, mask):
+ """Get points that fall within the mask in first frame
+
+ Args:
+ pred_tracks (torch.Tensor): [num_frames, num_points, 3]
+ mask (torch.Tensor): [H, W] binary mask
+
+ Returns:
+ torch.Tensor: Boolean mask of selected points [num_points]
+ """
+ first_frame_points = pred_tracks[0] # [num_points, 3]
+ xy_points = first_frame_points[:, :2] # [num_points, 2]
+
+ # Convert xy coordinates to pixel indices
+ xy_pixels = xy_points.round().long() # Convert to integer pixel coordinates
+
+ # Clamp coordinates to valid range
+ xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1) # x coordinates
+ xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1) # y coordinates
+
+ # Get mask values at point locations
+ points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]] # Index using y, x order
+
+ return points_in_mask
+
+ def generate_motion(self, mask, motion_type, distance, num_frames=49):
+ """Generate motion dictionary for the given parameters
+
+ Args:
+ mask (torch.Tensor): [H, W] binary mask
+ motion_type (str): Motion direction ('up', 'down', 'left', 'right')
+ distance (float): Total distance to move
+ num_frames (int): Number of frames
+
+ Returns:
+ dict: Motion dictionary containing:
+ - mask (torch.Tensor): Binary mask
+ - motions (torch.Tensor): Per-frame motion vectors [num_frames, 4, 4]
+ """
+
+ self.num_frames = num_frames
+ # Define motion template vectors
+ template = {
+ 'up': torch.tensor([0, -1, 0]),
+ 'down': torch.tensor([0, 1, 0]),
+ 'left': torch.tensor([-1, 0, 0]),
+ 'right': torch.tensor([1, 0, 0]),
+ 'front': torch.tensor([0, 0, 1]),
+ 'back': torch.tensor([0, 0, -1])
+ }
+
+ if motion_type not in template:
+ raise ValueError(f"Unknown motion type: {motion_type}")
+
+ # Move mask to device
+ mask = mask.to(self.device)
+
+ # Generate per-frame motion matrices
+ motions = []
+ base_vec = template[motion_type].to(self.device) * distance
+
+ for frame_idx in range(num_frames):
+ # Calculate interpolation factor (0 to 1)
+ t = frame_idx / (num_frames - 1)
+
+ # Create motion matrix for current frame
+ current_motion = torch.eye(4, device=self.device)
+ current_motion[:3, 3] = base_vec * t
+ motions.append(current_motion)
+
+ motions = torch.stack(motions) # [num_frames, 4, 4]
+
+ return {
+ 'mask': mask,
+ 'motions': motions
+ }
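+ # Usage sketch (illustrative names): pair generate_motion with apply_motion below.
+ # tracks: [num_frames, num_points, 3] point tracks; mask: [H, W] binary mask.
+ # gen = ObjectMotionGenerator()
+ # md = gen.generate_motion(mask, 'up', distance=50.0, num_frames=tracks.shape[0])
+ # moved = gen.apply_motion(tracks, md, tracking_method="spatracker")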
+
+ def apply_motion(self, pred_tracks, motion_dict, tracking_method="spatracker"):
+ """Apply motion to selected points
+
+ Args:
+ pred_tracks (torch.Tensor): [num_frames, num_points, 3] for spatracker
+ or [T, H, W, 3] for moge
+ motion_dict (dict): Motion dictionary containing mask and motions
+ tracking_method (str): "spatracker" or "moge"
+
+ Returns:
+ torch.Tensor: Modified pred_tracks with same shape as input
+ """
+ pred_tracks = pred_tracks.to(self.device).float()
+
+ if tracking_method == "moge":
+ T, H, W, _ = pred_tracks.shape
+
+ selected_mask = motion_dict['mask']
+ valid_selected = ~torch.any(torch.isnan(pred_tracks[0]), dim=2) & selected_mask
+ valid_selected = valid_selected.reshape([-1])
+ modified_tracks = pred_tracks.clone().reshape(T, -1, 3)
+
+ for frame_idx in range(self.num_frames):
+ motion_mat = motion_dict['motions'][frame_idx]
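+ # Scale the x/y translation by image width/height for the dense point-map branch
+ # (note: this modifies motion_dict['motions'] in place).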
+ motion_mat[0, 3] /= W
+ motion_mat[1, 3] /= H
+ points = modified_tracks[frame_idx, valid_selected]
+ points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
+
+ transformed_points = torch.matmul(points_homo, motion_mat.T)
+ modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3]
+ return modified_tracks
+
+ else:
+ points_in_mask = self._get_points_in_mask(pred_tracks, motion_dict['mask'])
+ modified_tracks = pred_tracks.clone()
+
+ for frame_idx in range(pred_tracks.shape[0]):
+ motion_mat = motion_dict['motions'][frame_idx]
+ points = modified_tracks[frame_idx, points_in_mask]
+ points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
+ transformed_points = torch.matmul(points_homo, motion_mat.T)
+ modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3]
+
+ return modified_tracks
diff --git a/libs/gs/__init__.py b/libs/gs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/libs/gs/gaussian_model.py b/libs/gs/gaussian_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2c1f625668e433a4ae762ff4ebdcd5edf0bc8c
--- /dev/null
+++ b/libs/gs/gaussian_model.py
@@ -0,0 +1,193 @@
+import torch
+import numpy as np
+from plyfile import PlyData, PlyElement
+from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
+
+class Gaussian:
+ def __init__(
+ self,
+ aabb : list,
+ sh_degree : int = 0,
+ mininum_kernel_size : float = 0.0,
+ scaling_bias : float = 0.01,
+ opacity_bias : float = 0.1,
+ scaling_activation : str = "exp",
+ device='cuda'
+ ):
+ self.init_params = {
+ 'aabb': aabb,
+ 'sh_degree': sh_degree,
+ 'mininum_kernel_size': mininum_kernel_size,
+ 'scaling_bias': scaling_bias,
+ 'opacity_bias': opacity_bias,
+ 'scaling_activation': scaling_activation,
+ }
+
+ self.sh_degree = sh_degree
+ self.active_sh_degree = sh_degree
+ self.mininum_kernel_size = mininum_kernel_size
+ self.scaling_bias = scaling_bias
+ self.opacity_bias = opacity_bias
+ self.scaling_activation_type = scaling_activation
+ self.device = device
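+ # aabb is [x_min, y_min, z_min, x_size, y_size, z_size]; positions are stored
+ # normalized to this box (see get_xyz / from_xyz below).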
+ self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
+ self.setup_functions()
+
+ self._xyz = None
+ self._features_dc = None
+ self._features_rest = None
+ self._scaling = None
+ self._rotation = None
+ self._opacity = None
+
+ def setup_functions(self):
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+ actual_covariance = L @ L.transpose(1, 2)
+ symm = strip_symmetric(actual_covariance)
+ return symm
+
+ if self.scaling_activation_type == "exp":
+ self.scaling_activation = torch.exp
+ self.inverse_scaling_activation = torch.log
+ elif self.scaling_activation_type == "softplus":
+ self.scaling_activation = torch.nn.functional.softplus
+ self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
+
+ self.covariance_activation = build_covariance_from_scaling_rotation
+
+ self.opacity_activation = torch.sigmoid
+ self.inverse_opacity_activation = inverse_sigmoid
+
+ self.rotation_activation = torch.nn.functional.normalize
+
+ self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda()
+ self.rots_bias = torch.zeros((4)).cuda()
+ self.rots_bias[0] = 1
+ self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda()
+
+ @property
+ def get_scaling(self):
+ scales = self.scaling_activation(self._scaling + self.scale_bias)
+ scales = torch.square(scales) + self.mininum_kernel_size ** 2
+ scales = torch.sqrt(scales)
+ return scales
+
+ @property
+ def get_rotation(self):
+ return self.rotation_activation(self._rotation + self.rots_bias[None, :])
+
+ @property
+ def get_xyz(self):
+ return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]
+
+ @property
+ def get_features(self):
+ return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc
+
+ @property
+ def get_opacity(self):
+ return self.opacity_activation(self._opacity + self.opacity_bias)
+
+ def get_covariance(self, scaling_modifier = 1):
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :])
+
+ def from_scaling(self, scales):
+ scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
+ self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias
+
+ def from_rotation(self, rots):
+ self._rotation = rots - self.rots_bias[None, :]
+
+ def from_xyz(self, xyz):
+ self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
+
+ def from_features(self, features):
+ self._features_dc = features
+
+ def from_opacity(self, opacities):
+ self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
+
+ def construct_list_of_attributes(self):
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+ # All channels except the 3 DC
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
+ l.append('f_dc_{}'.format(i))
+ l.append('opacity')
+ for i in range(self._scaling.shape[1]):
+ l.append('scale_{}'.format(i))
+ for i in range(self._rotation.shape[1]):
+ l.append('rot_{}'.format(i))
+ return l
+
+ def save_ply(self, path):
+ xyz = self.get_xyz.detach().cpu().numpy()
+ normals = np.zeros_like(xyz)
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+ opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
+ scale = torch.log(self.get_scaling).detach().cpu().numpy()
+ rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
+
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
+ attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, 'vertex')
+ PlyData([el]).write(path)
+
+ def load_ply(self, path):
+ plydata = PlyData.read(path)
+
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+ np.asarray(plydata.elements[0]["y"]),
+ np.asarray(plydata.elements[0]["z"])), axis=1)
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+ if self.sh_degree > 0:
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+ extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
+ assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+ for idx, attr_name in enumerate(extra_f_names):
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1))
+
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+ scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
+ for idx, attr_name in enumerate(scale_names):
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+ rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
+ for idx, attr_name in enumerate(rot_names):
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ # convert to actual gaussian attributes
+ xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
+ features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
+ if self.sh_degree > 0:
+ features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
+ opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
+ scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
+ rots = torch.tensor(rots, dtype=torch.float, device=self.device)
+
+ # convert to _hidden attributes
+ self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
+ self._features_dc = features_dc
+ if self.sh_degree > 0:
+ self._features_rest = features_extra
+ else:
+ self._features_rest = None
+ self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
+ self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias
+ self._rotation = rots - self.rots_bias[None, :]
+
\ No newline at end of file
diff --git a/libs/gs/gaussian_renderer.py b/libs/gs/gaussian_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d490ad679318b455dd2486cc21f751f6daa8f597
--- /dev/null
+++ b/libs/gs/gaussian_renderer.py
@@ -0,0 +1,230 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import math
+from easydict import EasyDict as edict
+import numpy as np
+from .gaussian_model import Gaussian
+from .sh_utils import eval_sh
+import torch.nn.functional as F
+
+def intrinsics_to_projection(
+ intrinsics: torch.Tensor,
+ near: float,
+ far: float,
+ ) -> torch.Tensor:
+ """
+ OpenCV intrinsics to OpenGL perspective matrix
+
+ Args:
+ intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+ near (float): near plane to clip
+ far (float): far plane to clip
+ Returns:
+ (torch.Tensor): [4, 4] OpenGL perspective matrix
+ """
+ fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+ cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+ ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+ ret[0, 0] = 2 * fx
+ ret[1, 1] = 2 * fy
+ ret[0, 2] = 2 * cx - 1
+ ret[1, 2] = - 2 * cy + 1
+ ret[2, 2] = far / (far - near)
+ ret[2, 3] = near * far / (near - far)
+ ret[3, 2] = 1.
+ return ret
+
+
+def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
+ """
+ Render the scene.
+
+ Background tensor (bg_color) must be on GPU!
+ """
+ # lazy import
+ if 'GaussianRasterizer' not in globals():
+ # from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
+ from diff_gauss import GaussianRasterizationSettings, GaussianRasterizer
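+ # The diff_gauss backend is imported (instead of the commented-out rasterizer above)
+ # because the call below unpacks (image, depth, alpha, radii) from the rasterizer.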
+
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+ screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+ try:
+ screenspace_points.retain_grad()
+ except:
+ pass
+ # Set up rasterization configuration
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+ kernel_size = pipe.kernel_size
+ subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
+
+ raster_settings = GaussianRasterizationSettings(
+ image_height=int(viewpoint_camera.image_height),
+ image_width=int(viewpoint_camera.image_width),
+ tanfovx=tanfovx,
+ tanfovy=tanfovy,
+ # kernel_size=kernel_size,
+ # subpixel_offset=subpixel_offset,
+ bg=bg_color,
+ scale_modifier=scaling_modifier,
+ viewmatrix=viewpoint_camera.world_view_transform,
+ projmatrix=viewpoint_camera.full_proj_transform,
+ sh_degree=pc.active_sh_degree,
+ campos=viewpoint_camera.camera_center,
+ prefiltered=False,
+ debug=pipe.debug
+ )
+
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+ means3D = pc.get_xyz
+ means2D = screenspace_points
+ opacity = pc.get_opacity
+
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+ # scaling / rotation by the rasterizer.
+ scales = None
+ rotations = None
+ cov3D_precomp = None
+ if pipe.compute_cov3D_python:
+ cov3D_precomp = pc.get_covariance(scaling_modifier)
+ else:
+ scales = pc.get_scaling
+ rotations = pc.get_rotation
+
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+ shs = None
+ colors_precomp = None
+ if override_color is None:
+ if pipe.convert_SHs_python:
+ shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.sh_degree+1)**2)
+ dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
+ dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
+ sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+ colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+ else:
+ shs = pc.get_features
+ else:
+ colors_precomp = override_color
+
+ rendered_image, depth, alpha, radii = rasterizer(
+ means3D=means3D,
+ means2D=means2D,
+ shs=shs,
+ colors_precomp=colors_precomp,
+ opacities=opacity,
+ scales=scales,
+ rotations=rotations,
+ cov3D_precomp=cov3D_precomp)
+
+ return edict({"render": rendered_image,
+ "viewspace_points": screenspace_points,
+ "visibility_filter" : radii > 0,
+ "depth": depth,
+ "alpha": alpha,
+ "radii": radii})
+
+class GaussianRenderer:
+ """
+ Renderer for the Gaussian representation.
+
+ Args:
+ rendering_options (dict): Rendering options.
+ """
+
+ def __init__(self, rendering_options={}) -> None:
+ self.pipe = edict({
+ "kernel_size": 0.1,
+ "convert_SHs_python": False,
+ "compute_cov3D_python": False,
+ "scale_modifier": 1.0,
+ "debug": False
+ })
+ self.rendering_options = edict({
+ "resolution": None,
+ "near": None,
+ "far": None,
+ "ssaa": 1,
+ "bg_color": 'random',
+ })
+ self.rendering_options.update(rendering_options)
+ self.bg_color = None
+
+ def render(
+ self,
+ gaussian: Gaussian,
+ extrinsics: torch.Tensor,
+ intrinsics: torch.Tensor,
+ colors_overwrite: torch.Tensor = None
+ ) -> edict:
+ """
+ Render the Gaussian model.
+
+ Args:
+ gaussian (Gaussian): Gaussian model to render
+ extrinsics (torch.Tensor): (4, 4) camera extrinsics
+ intrinsics (torch.Tensor): (3, 3) camera intrinsics
+ colors_overwrite (torch.Tensor): (N, 3) override color
+
+ Returns:
+ edict containing:
+ color (torch.Tensor): (3, H, W) rendered color image
+ depth (torch.Tensor): rendered depth map
+ alpha (torch.Tensor): rendered alpha map
+ projection (torch.Tensor): (4, 4) full projection transform
+ """
+ resolution = self.rendering_options["resolution"]
+ near = self.rendering_options["near"]
+ far = self.rendering_options["far"]
+ ssaa = self.rendering_options["ssaa"]
+
+ if self.rendering_options["bg_color"] == 'random':
+ self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
+ if np.random.rand() < 0.5:
+ self.bg_color += 1
+ else:
+ self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
+
+ view = extrinsics
+ perspective = intrinsics_to_projection(intrinsics, near, far)
+ camera = torch.inverse(view)[:3, 3]
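+ # Intrinsics are assumed normalized (focal length in units of the image size),
+ # so the field of view follows from f = 0.5 / tan(FoV / 2).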
+ focalx = intrinsics[0, 0]
+ focaly = intrinsics[1, 1]
+ fovx = 2 * torch.atan(0.5 / focalx)
+ fovy = 2 * torch.atan(0.5 / focaly)
+
+ camera_dict = edict({
+ "image_height": resolution * ssaa,
+ "image_width": resolution * ssaa,
+ "FoVx": fovx,
+ "FoVy": fovy,
+ "znear": near,
+ "zfar": far,
+ "world_view_transform": view.T.contiguous(),
+ "projection_matrix": perspective.T.contiguous(),
+ "full_proj_transform": (perspective @ view).T.contiguous(),
+ "camera_center": camera
+ })
+
+ # Render
+ render_ret = render(camera_dict, gaussian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)
+
+ if ssaa > 1:
+ render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+
+ ret = edict({
+ 'projection': camera_dict['full_proj_transform'],
+ 'color': render_ret.render,
+ 'depth': render_ret.depth,
+ 'alpha': render_ret.alpha
+ })
+ return ret
diff --git a/libs/gs/general_utils.py b/libs/gs/general_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d
--- /dev/null
+++ b/libs/gs/general_utils.py
@@ -0,0 +1,133 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import sys
+from datetime import datetime
+import numpy as np
+import random
+
+def inverse_sigmoid(x):
+ return torch.log(x/(1-x))
+
+def PILtoTorch(pil_image, resolution):
+ resized_image_PIL = pil_image.resize(resolution)
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+ if len(resized_image.shape) == 3:
+ return resized_image.permute(2, 0, 1)
+ else:
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+def get_expon_lr_func(
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+ """
+ Copied from Plenoxels
+
+ Continuous learning rate decay function. Adapted from JaxNeRF
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+ function of lr_delay_mult, such that the initial learning rate is
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+ to the normal learning rate when steps>lr_delay_steps.
+ :param conf: config subtree 'lr' or similar
+ :param max_steps: int, the number of steps during optimization.
+ :return HoF which takes step as input
+ """
+
+ def helper(step):
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+ # Disable this parameter
+ return 0.0
+ if lr_delay_steps > 0:
+ # A kind of reverse cosine decay.
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+ )
+ else:
+ delay_rate = 1.0
+ t = np.clip(step / max_steps, 0, 1)
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+ return delay_rate * log_lerp
+
+ return helper
+
+def strip_lowerdiag(L):
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+
+ uncertainty[:, 0] = L[:, 0, 0]
+ uncertainty[:, 1] = L[:, 0, 1]
+ uncertainty[:, 2] = L[:, 0, 2]
+ uncertainty[:, 3] = L[:, 1, 1]
+ uncertainty[:, 4] = L[:, 1, 2]
+ uncertainty[:, 5] = L[:, 2, 2]
+ return uncertainty
+
+def strip_symmetric(sym):
+ return strip_lowerdiag(sym)
+
+def build_rotation(r):
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
+
+ q = r / norm[:, None]
+
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+ r = q[:, 0]
+ x = q[:, 1]
+ y = q[:, 2]
+ z = q[:, 3]
+
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+ R[:, 0, 1] = 2 * (x*y - r*z)
+ R[:, 0, 2] = 2 * (x*z + r*y)
+ R[:, 1, 0] = 2 * (x*y + r*z)
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+ R[:, 1, 2] = 2 * (y*z - r*x)
+ R[:, 2, 0] = 2 * (x*z - r*y)
+ R[:, 2, 1] = 2 * (y*z + r*x)
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+ return R
+
+def build_scaling_rotation(s, r):
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+ R = build_rotation(r)
+
+ L[:,0,0] = s[:,0]
+ L[:,1,1] = s[:,1]
+ L[:,2,2] = s[:,2]
+
+ L = R @ L
+ return L
+
+def safe_state(silent):
+ old_f = sys.stdout
+ class F:
+ def __init__(self, silent):
+ self.silent = silent
+
+ def write(self, x):
+ if not self.silent:
+ if x.endswith("\n"):
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
+ else:
+ old_f.write(x)
+
+ def flush(self):
+ old_f.flush()
+
+ sys.stdout = F(silent)
+
+ random.seed(0)
+ np.random.seed(0)
+ torch.manual_seed(0)
+ torch.cuda.set_device(torch.device("cuda:0"))
diff --git a/libs/gs/sh_utils.py b/libs/gs/sh_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785
--- /dev/null
+++ b/libs/gs/sh_utils.py
@@ -0,0 +1,118 @@
+# Copyright 2021 The PlenOctree Authors.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+]
+C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+]
+C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761,
+]
+
+
+def eval_sh(deg, sh, dirs):
+ """
+ Evaluate spherical harmonics at unit directions
+ using hardcoded SH polynomials.
+ Works with torch/np/jnp.
+ ... Can be 0 or more batch dimensions.
+ Args:
+ deg: int SH deg. Currently, 0-4 supported
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+ dirs: jnp.ndarray unit directions [..., 3]
+ Returns:
+ [..., C]
+ """
+ assert deg <= 4 and deg >= 0
+ coeff = (deg + 1) ** 2
+ assert sh.shape[-1] >= coeff
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
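+# For degree-0 SH, shading reduces to color = C0 * f_dc + 0.5 (see eval_sh above and the
+# +0.5 offset applied in the renderer), which is why the conversions below only involve C0 and 0.5.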
+def RGB2SH(rgb):
+ return (rgb - 0.5) / C0
+
+def SH2RGB(sh):
+ return sh * C0 + 0.5
\ No newline at end of file
diff --git a/libs/sam2/__init__.py b/libs/sam2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0712dd03cb280ab94ba04f8a32aa8ddc8aa3db4a
--- /dev/null
+++ b/libs/sam2/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from hydra import initialize_config_module
+from hydra.core.global_hydra import GlobalHydra
+
+if not GlobalHydra.instance().is_initialized():
+ initialize_config_module("sam2", version_base="1.2")
diff --git a/libs/sam2/automatic_mask_generator.py b/libs/sam2/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..065e469e27c2d3af40d51d072031e828692c799b
--- /dev/null
+++ b/libs/sam2/automatic_mask_generator.py
@@ -0,0 +1,454 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area # type: ignore
+
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2.utils.amg import (
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ MaskData,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class SAM2AutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: SAM2Base,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.8,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ mask_threshold: float = 0.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ use_m2m: bool = False,
+ multimask_output: bool = True,
+ **kwargs,
+ ) -> None:
+ """
+ Using a SAM 2 model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM 2 with a HieraL backbone.
+
+ Arguments:
+ model (Sam): The SAM 2 model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+ calculating the stability score.
+ mask_threshold (float): Threshold for binarizing the mask logits
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+ crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
+ multimask_output (bool): Whether to output multimask at each point of the grid.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ try:
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+ except ImportError as e:
+ print("Please install pycocotools")
+ raise e
+
+ self.predictor = SAM2ImagePredictor(
+ model,
+ max_hole_area=min_mask_region_area,
+ max_sprinkle_area=min_mask_region_area,
+ )
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.mask_threshold = mask_threshold
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+ self.use_m2m = use_m2m
+ self.multimask_output = multimask_output
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
+ """
+ Load a pretrained model from the Hugging Face hub.
+
+ Arguments:
+ model_id (str): The Hugging Face repository ID.
+ **kwargs: Additional arguments to pass to the model constructor.
+
+ Returns:
+ (SAM2AutomaticMaskGenerator): The loaded model.
+ """
+ from sam2.build_sam import build_sam2_hf
+
+ sam_model = build_sam2_hf(model_id, **kwargs)
+ return cls(sam_model, **kwargs)
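+ # Usage sketch (illustrative): build a generator from the Hugging Face hub and
+ # segment an HWC uint8 image.
+ # generator = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2.1-hiera-large")
+ # anns = generator.generate(image)
+ # masks = [ann["segmentation"] for ann in anns]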
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [
+ coco_encode_rle(rle) for rle in mask_data["rles"]
+ ]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(
+ points, cropped_im_size, crop_box, orig_size, normalize=True
+ )
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_predictor()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ normalize=False,
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ points = torch.as_tensor(
+ points, dtype=torch.float32, device=self.predictor.device
+ )
+ in_points = self.predictor._transforms.transform_coords(
+ points, normalize=normalize, orig_hw=im_size
+ )
+ in_labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device
+ )
+ masks, iou_preds, low_res_masks = self.predictor._predict(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=self.multimask_output,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=points.repeat_interleave(masks.shape[1], dim=0),
+ low_res_masks=low_res_masks.flatten(0, 1),
+ )
+ del masks
+
+ if not self.use_m2m:
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate and filter by stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+ else:
+ # One step refinement using previous mask predictions
+ in_points = self.predictor._transforms.transform_coords(
+ data["points"], normalize=normalize, orig_hw=im_size
+ )
+ labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device
+ )
+ masks, ious = self.refine_with_m2m(
+ in_points, labels, data["low_res_masks"], self.points_per_batch
+ )
+ data["masks"] = masks.squeeze(1)
+ data["iou_preds"] = ious.squeeze(1)
+
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
+ )
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros_like(boxes[:, 0]), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
+
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
+ new_masks = []
+ new_iou_preds = []
+
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
+ points_per_batch, points, point_labels, low_res_masks
+ ):
+ best_masks, best_iou_preds, _ = self.predictor._predict(
+ cur_points[:, None, :],
+ cur_point_labels[:, None],
+ mask_input=low_res_mask[:, None, :],
+ multimask_output=False,
+ return_logits=True,
+ )
+ new_masks.append(best_masks)
+ new_iou_preds.append(best_iou_preds)
+ masks = torch.cat(new_masks, dim=0)
+ return masks, torch.cat(new_iou_preds, dim=0)
diff --git a/libs/sam2/build_sam.py b/libs/sam2/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..223dcd51db70aebcf0536a2bb9d9ec2308a17fa5
--- /dev/null
+++ b/libs/sam2/build_sam.py
@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+
+import torch
+from hydra import compose
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+
+import sam2
+
+# Check if the user is running Python from the parent directory of the sam2 repo
+# (i.e. the directory where this repo is cloned into) -- this is not supported since
+# it could shadow the sam2 package and cause issues.
+if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
+ # If the user has "sam2/sam2" in their path, they are likey importing the repo itself
+ # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
+ # This typically happens because the user is running Python from the parent directory
+ # that contains the sam2 repo they cloned.
+ raise RuntimeError(
+ "You're likely running Python from the parent directory of the sam2 repository "
+ "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
+ "This is not supported since the `sam2` Python package could be shadowed by the "
+ "repository name (the repository is also named `sam2` and contains the Python package "
+ "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
+ "rather than its parent dir, or from your home directory) after installing SAM 2."
+ )
+
+
+HF_MODEL_ID_TO_FILENAMES = {
+ "facebook/sam2-hiera-tiny": (
+ "configs/sam2/sam2_hiera_t.yaml",
+ "sam2_hiera_tiny.pt",
+ ),
+ "facebook/sam2-hiera-small": (
+ "configs/sam2/sam2_hiera_s.yaml",
+ "sam2_hiera_small.pt",
+ ),
+ "facebook/sam2-hiera-base-plus": (
+ "configs/sam2/sam2_hiera_b+.yaml",
+ "sam2_hiera_base_plus.pt",
+ ),
+ "facebook/sam2-hiera-large": (
+ "configs/sam2/sam2_hiera_l.yaml",
+ "sam2_hiera_large.pt",
+ ),
+ "facebook/sam2.1-hiera-tiny": (
+ "configs/sam2.1/sam2.1_hiera_t.yaml",
+ "sam2.1_hiera_tiny.pt",
+ ),
+ "facebook/sam2.1-hiera-small": (
+ "configs/sam2.1/sam2.1_hiera_s.yaml",
+ "sam2.1_hiera_small.pt",
+ ),
+ "facebook/sam2.1-hiera-base-plus": (
+ "configs/sam2.1/sam2.1_hiera_b+.yaml",
+ "sam2.1_hiera_base_plus.pt",
+ ),
+ "facebook/sam2.1-hiera-large": (
+ "configs/sam2.1/sam2.1_hiera_l.yaml",
+ "sam2.1_hiera_large.pt",
+ ),
+}
+
+
+def build_sam2(
+ config_file,
+ ckpt_path=None,
+ device="cuda",
+ mode="eval",
+ hydra_overrides_extra=[],
+ apply_postprocessing=True,
+ **kwargs,
+):
+
+ if apply_postprocessing:
+ hydra_overrides_extra = hydra_overrides_extra.copy()
+ hydra_overrides_extra += [
+ # dynamically fall back to multi-mask if the single mask is not stable
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+ ]
+ # Read config and init model
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
+ OmegaConf.resolve(cfg)
+ model = instantiate(cfg.model, _recursive_=True)
+ _load_checkpoint(model, ckpt_path)
+ model = model.to(device)
+ if mode == "eval":
+ model.eval()
+ return model
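+
+# Minimal usage sketch for build_sam2 (paths are illustrative and depend on how the
+# configs are installed):
+# sam_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", ckpt_path="checkpoints/sam2.1_hiera_large.pt")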
+
+
+def build_sam2_video_predictor(
+ config_file,
+ ckpt_path=None,
+ device="cuda",
+ mode="eval",
+ hydra_overrides_extra=[],
+ apply_postprocessing=True,
+ **kwargs,
+):
+ hydra_overrides = [
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
+ ]
+ if apply_postprocessing:
+ hydra_overrides_extra = hydra_overrides_extra.copy()
+ hydra_overrides_extra += [
+ # dynamically fall back to multi-mask if the single mask is not stable
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
+ "++model.fill_hole_area=8",
+ ]
+ hydra_overrides.extend(hydra_overrides_extra)
+
+ # Read config and init model
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
+ OmegaConf.resolve(cfg)
+ model = instantiate(cfg.model, _recursive_=True)
+ _load_checkpoint(model, ckpt_path)
+ model = model.to(device)
+ if mode == "eval":
+ model.eval()
+ return model
+
+
+def _hf_download(model_id, **kwargs):
+ from huggingface_hub import hf_hub_download
+
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name, **kwargs)
+ return config_name, ckpt_path
+
+
+def build_sam2_hf(model_id, cache_dir=None, device="cuda"):
+ config_name, ckpt_path = _hf_download(model_id, cache_dir=cache_dir)
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, device=device)
+
+
+def build_sam2_video_predictor_hf(model_id, **kwargs):
+ config_name, ckpt_path = _hf_download(model_id)
+ return build_sam2_video_predictor(
+ config_file=config_name, ckpt_path=ckpt_path, **kwargs
+ )
+
+
+def _load_checkpoint(model, ckpt_path):
+ if ckpt_path is not None:
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
+ if missing_keys:
+ logging.error(missing_keys)
+ raise RuntimeError()
+ if unexpected_keys:
+ logging.error(unexpected_keys)
+ raise RuntimeError()
+ logging.info("Loaded checkpoint sucessfully")
diff --git a/libs/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/libs/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cbee3cf9b3977ebe4cc868797a9bfa9e348cb3a3
--- /dev/null
+++ b/libs/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/libs/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/libs/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..33c9097f34ea90beae52776eb88ad8eb1632ab66
--- /dev/null
+++ b/libs/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
@@ -0,0 +1,120 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 144
+ num_heads: 2
+ stages: [2, 6, 36, 4]
+ global_att_blocks: [23, 33, 43]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ window_spec: [8, 4, 16, 8]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [1152, 576, 288, 144]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/libs/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/libs/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e803dfea5904f5eb5e73981918c913197587728
--- /dev/null
+++ b/libs/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
@@ -0,0 +1,119 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 11, 2]
+ global_att_blocks: [7, 10, 13]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/libs/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/libs/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..983c2ea031b7a17db439fe89fa8b7bd426ecd9bb
--- /dev/null
+++ b/libs/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
@@ -0,0 +1,121 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 7, 2]
+ global_att_blocks: [5, 7, 9]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ # SAM decoder
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # HieraT does not currently support compilation, should always be set to False
+ compile_image_encoder: False
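The sam2.1_hiera_*.yaml files above are Hydra configs: every _target_ entry names the class to instantiate, so the model: tree resolves into a sam2.modeling.sam2_base.SAM2Base. A minimal usage sketch, assuming the standard sam2 package API and an already-downloaded checkpoint (paths are illustrative, not part of this diff):

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Resolve the Hydra config into a SAM2Base model and wrap it for prompted image segmentation.
model = build_sam2("configs/sam2.1/sam2.1_hiera_t.yaml", "checkpoints/sam2.1_hiera_tiny.pt", device="cuda")
predictor = SAM2ImagePredictor(model)
# predictor.set_image(rgb_array); masks, scores, _ = predictor.predict(point_coords=pts, point_labels=lbls)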
diff --git a/libs/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/libs/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..204679146854110ce8a59e9adc462a6688e56d30
--- /dev/null
+++ b/libs/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
@@ -0,0 +1,339 @@
+# @package _global_
+
+scratch:
+ resolution: 1024
+ train_batch_size: 1
+ num_train_workers: 10
+ num_frames: 8
+ max_num_objects: 3
+ base_lr: 5.0e-6
+ vision_lr: 3.0e-06
+ phases_per_epoch: 1
+ num_epochs: 40
+
+dataset:
+ # PATHS to Dataset
+ img_folder: null # PATH to MOSE JPEGImages folder
+ gt_folder: null # PATH to MOSE Annotations folder
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
+ multiplier: 2
+
+# Video transforms
+vos:
+ train_transforms:
+ - _target_: training.dataset.transforms.ComposeAPI
+ transforms:
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomAffine
+ degrees: 25
+ shear: 20
+ image_interpolation: bilinear
+ consistent_transform: True
+ - _target_: training.dataset.transforms.RandomResizeAPI
+ sizes: ${scratch.resolution}
+ square: true
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: True
+ brightness: 0.1
+ contrast: 0.03
+ saturation: 0.03
+ hue: null
+ - _target_: training.dataset.transforms.RandomGrayscale
+ p: 0.05
+ consistent_transform: True
+ - _target_: training.dataset.transforms.ColorJitter
+ consistent_transform: False
+ brightness: 0.1
+ contrast: 0.05
+ saturation: 0.05
+ hue: null
+ - _target_: training.dataset.transforms.ToTensorAPI
+ - _target_: training.dataset.transforms.NormalizeAPI
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+
+trainer:
+ _target_: training.trainer.Trainer
+ mode: train_only
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
+ accelerator: cuda
+ seed_value: 123
+
+ model:
+ _target_: training.model.sam2.SAM2Train
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ drop_path_rate: 0.1
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: ${scratch.resolution}
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ no_obj_embed_spatial: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: true
+ proj_tpos_enc_in_obj_ptrs: true
+ use_signed_tpos_enc_to_obj_ptrs: true
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # compile_image_encoder: False
+
+ ####### Training specific params #######
+ # box/point input and corrections
+ prob_to_use_pt_input_for_train: 0.5
+ prob_to_use_pt_input_for_eval: 0.0
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
+ prob_to_use_box_input_for_eval: 0.0
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
+ # maximum 2 initial conditioning frames
+ num_init_cond_frames_for_train: 2
+ rand_init_cond_frames_for_train: True # random 1~2
+ num_correction_pt_per_frame: 7
+ use_act_ckpt_iterative_pt_sampling: false
+
+
+
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
+ forward_backbone_per_frame_for_eval: True
+
+
+ data:
+ train:
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
+ phases_per_epoch: ${scratch.phases_per_epoch}
+ batch_sizes:
+ - ${scratch.train_batch_size}
+
+ datasets:
+ - _target_: training.dataset.utils.RepeatFactorWrapper
+ dataset:
+ _target_: training.dataset.utils.ConcatDataset
+ datasets:
+ - _target_: training.dataset.vos_dataset.VOSDataset
+ transforms: ${vos.train_transforms}
+ training: true
+ video_dataset:
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
+ img_folder: ${dataset.img_folder}
+ gt_folder: ${dataset.gt_folder}
+ file_list_txt: ${dataset.file_list_txt}
+ sampler:
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
+ num_frames: ${scratch.num_frames}
+ max_num_objects: ${scratch.max_num_objects}
+ multiplier: ${dataset.multiplier}
+ shuffle: True
+ num_workers: ${scratch.num_train_workers}
+ pin_memory: True
+ drop_last: True
+ collate_fn:
+ _target_: training.utils.data_utils.collate_fn
+ _partial_: true
+ dict_key: all
+
+ optim:
+ amp:
+ enabled: True
+ amp_dtype: bfloat16
+
+ optimizer:
+ _target_: torch.optim.AdamW
+
+ gradient_clip:
+ _target_: training.optimizer.GradientClipper
+ max_norm: 0.1
+ norm_type: 2
+
+ param_group_modifiers:
+ - _target_: training.optimizer.layer_decay_param_modifier
+ _partial_: True
+ layer_decay_value: 0.9
+ apply_to: 'image_encoder.trunk'
+ overrides:
+ - pattern: '*pos_embed*'
+ value: 1.0
+
+ options:
+ lr:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.base_lr}
+ end_value: ${divide:${scratch.base_lr},10}
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
+ start_value: ${scratch.vision_lr}
+ end_value: ${divide:${scratch.vision_lr},10}
+ param_names:
+ - 'image_encoder.*'
+ weight_decay:
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.1
+ - scheduler:
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
+ value: 0.0
+ param_names:
+ - '*bias*'
+ module_cls_names: ['torch.nn.LayerNorm']
+
+ loss:
+ all:
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
+ weight_dict:
+ loss_mask: 20
+ loss_dice: 1
+ loss_iou: 1
+ loss_class: 1
+ supervise_all_iou: true
+ iou_use_l1_loss: true
+ pred_obj_scores: true
+ focal_gamma_obj_score: 0.0
+ focal_alpha_obj_score: -1.0
+
+ distributed:
+ backend: nccl
+ find_unused_parameters: True
+
+ logging:
+ tensorboard_writer:
+ _target_: training.utils.logger.make_tensorboard_logger
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
+ flush_secs: 120
+ should_log: True
+ log_dir: ${launcher.experiment_log_dir}/logs
+ log_freq: 10
+
+ # initialize from a SAM 2 checkpoint
+ checkpoint:
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
+ save_freq: 0 # 0 only last checkpoint is saved.
+ model_weight_initializer:
+ _partial_: True
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
+ strict: True
+ ignore_unexpected_keys: null
+ ignore_missing_keys: null
+
+ state_dict:
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
+ ckpt_state_dict_keys: ['model']
+
+launcher:
+ num_nodes: 1
+ gpus_per_node: 8
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
+
+# SLURM args if running on a cluster
+submitit:
+ partition: null
+ account: null
+ qos: null
+ cpus_per_task: 10
+ use_cluster: false
+ timeout_hour: 24
+ name: null
+ port_range: [10000, 65000]
+
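This training config is wired together through OmegaConf interpolation: the scratch block holds the tunable knobs, and ${scratch.*}, ${times:...} and ${divide:...} references fan them out into the trainer, data and optimizer sections (dataset.img_folder and dataset.gt_folder must be filled in before launching). A small sketch of how those interpolations resolve, assuming omegaconf 2.1+ is installed and the file is read from the repo root; the times/divide resolvers are normally registered by the upstream training entry point, so the stand-ins below are an assumption made only for standalone inspection:

from omegaconf import OmegaConf

# Stand-ins for the resolvers the training entry point registers (assumption).
OmegaConf.register_new_resolver("times", lambda a, b: a * b, replace=True)
OmegaConf.register_new_resolver("divide", lambda a, b: a / b, replace=True)

cfg = OmegaConf.load("libs/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml")
print(cfg.trainer.max_epochs)        # 40   (= scratch.num_epochs * scratch.phases_per_epoch)
print(cfg.trainer.model.image_size)  # 1024 (follows ${scratch.resolution})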
diff --git a/libs/sam2/configs/sam2/sam2_hiera_b+.yaml b/libs/sam2/configs/sam2/sam2_hiera_b+.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..58f3eb81554018e873f8515ecb98e36d16ac29e4
--- /dev/null
+++ b/libs/sam2/configs/sam2/sam2_hiera_b+.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/libs/sam2/configs/sam2/sam2_hiera_l.yaml b/libs/sam2/configs/sam2/sam2_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..918667f50c3e1ad2dcf77c0c14cb4dd114cfd080
--- /dev/null
+++ b/libs/sam2/configs/sam2/sam2_hiera_l.yaml
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 144
+ num_heads: 2
+ stages: [2, 6, 36, 4]
+ global_att_blocks: [23, 33, 43]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ window_spec: [8, 4, 16, 8]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [1152, 576, 288, 144]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/libs/sam2/configs/sam2/sam2_hiera_s.yaml b/libs/sam2/configs/sam2/sam2_hiera_s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2
--- /dev/null
+++ b/libs/sam2/configs/sam2/sam2_hiera_s.yaml
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 11, 2]
+ global_att_blocks: [7, 10, 13]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/libs/sam2/configs/sam2/sam2_hiera_t.yaml b/libs/sam2/configs/sam2/sam2_hiera_t.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a62c903aaa5f80828077c6e06a59626926570ed6
--- /dev/null
+++ b/libs/sam2/configs/sam2/sam2_hiera_t.yaml
@@ -0,0 +1,118 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 7, 2]
+ global_att_blocks: [5, 7, 9]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ # SAM decoder
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # HieraT does not currently support compilation, should always be set to False
+ compile_image_encoder: False
diff --git a/libs/sam2/csrc/connected_components.cu b/libs/sam2/csrc/connected_components.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ced21eb32eaaadb818d441c1322b99d1bf068f45
--- /dev/null
+++ b/libs/sam2/csrc/connected_components.cu
@@ -0,0 +1,289 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// adapted from https://github.com/zsef123/Connected_components_PyTorch
+// with license found in the LICENSE_cctorch file in the root directory.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <vector>
+
+// 2d
+#define BLOCK_ROWS 16
+#define BLOCK_COLS 16
+
+namespace cc2d {
+
+template <typename T>
+__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
+ return (bitmap >> pos) & 1;
+}
+
+__device__ int32_t find(const int32_t* s_buf, int32_t n) {
+ while (s_buf[n] != n)
+ n = s_buf[n];
+ return n;
+}
+
+__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
+ const int32_t id = n;
+ while (s_buf[n] != n) {
+ n = s_buf[n];
+ s_buf[id] = n;
+ }
+ return n;
+}
+
+__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
+ bool done;
+ do {
+ a = find(s_buf, a);
+ b = find(s_buf, b);
+
+ if (a < b) {
+ int32_t old = atomicMin(s_buf + b, a);
+ done = (old == b);
+ b = old;
+ } else if (b < a) {
+ int32_t old = atomicMin(s_buf + a, b);
+ done = (old == a);
+ a = old;
+ } else
+ done = true;
+
+ } while (!done);
+}
+
+__global__ void
+init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row < H && col < W)
+ label[idx] = idx;
+}
+
+__global__ void
+merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ uint32_t P = 0;
+
+ if (img[idx])
+ P |= 0x777;
+ if (row + 1 < H && img[idx + W])
+ P |= 0x777 << 4;
+ if (col + 1 < W && img[idx + 1])
+ P |= 0x777 << 1;
+
+ if (col == 0)
+ P &= 0xEEEE;
+ if (col + 1 >= W)
+ P &= 0x3333;
+ else if (col + 2 >= W)
+ P &= 0x7777;
+
+ if (row == 0)
+ P &= 0xFFF0;
+ if (row + 1 >= H)
+ P &= 0xFF;
+
+ if (P > 0) {
+ // If need check about top-left pixel(if flag the first bit) and hit the
+ // top-left pixel
+ if (hasBit(P, 0) && img[idx - W - 1]) {
+ union_(label, idx, idx - 2 * W - 2); // top left block
+ }
+
+ if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
+ union_(label, idx, idx - 2 * W); // top bottom block
+
+ if (hasBit(P, 3) && img[idx + 2 - W])
+ union_(label, idx, idx - 2 * W + 2); // top right block
+
+ if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
+ union_(label, idx, idx - 2); // just left block
+ }
+}
+
+__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row < H && col < W)
+ find_n_compress(label, idx);
+}
+
+__global__ void final_labeling(
+ const uint8_t* img,
+ int32_t* label,
+ const int32_t W,
+ const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ int32_t y = label[idx] + 1;
+
+ if (img[idx])
+ label[idx] = y;
+ else
+ label[idx] = 0;
+
+ if (col + 1 < W) {
+ if (img[idx + 1])
+ label[idx + 1] = y;
+ else
+ label[idx + 1] = 0;
+
+ if (row + 1 < H) {
+ if (img[idx + W + 1])
+ label[idx + W + 1] = y;
+ else
+ label[idx + W + 1] = 0;
+ }
+ }
+
+ if (row + 1 < H) {
+ if (img[idx + W])
+ label[idx + W] = y;
+ else
+ label[idx + W] = 0;
+ }
+}
+
+__global__ void init_counting(
+ const int32_t* label,
+ int32_t* count_init,
+ const int32_t W,
+ const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ int32_t y = label[idx];
+ if (y > 0) {
+ int32_t count_idx = y - 1;
+ atomicAdd(count_init + count_idx, 1);
+ }
+}
+
+__global__ void final_counting(
+ const int32_t* label,
+ const int32_t* count_init,
+ int32_t* count_final,
+ const int32_t W,
+ const int32_t H) {
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
+ const uint32_t idx = row * W + col;
+
+ if (row >= H || col >= W)
+ return;
+
+ int32_t y = label[idx];
+ if (y > 0) {
+ int32_t count_idx = y - 1;
+ count_final[idx] = count_init[count_idx];
+ } else {
+ count_final[idx] = 0;
+ }
+}
+
+} // namespace cc2d
+
+std::vector<torch::Tensor> get_connected_componnets(
+ const torch::Tensor& inputs) {
+ AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
+ AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
+ AT_ASSERTM(
+ inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
+
+ const uint32_t N = inputs.size(0);
+ const uint32_t C = inputs.size(1);
+ const uint32_t H = inputs.size(2);
+ const uint32_t W = inputs.size(3);
+
+ AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
+ AT_ASSERTM((H % 2) == 0, "height must be an even number");
+ AT_ASSERTM((W % 2) == 0, "width must be an even number");
+
+ // label must be uint32_t
+ auto label_options =
+ torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
+ torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
+ torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
+ torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
+
+ dim3 grid = dim3(
+ ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
+ ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
+ dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
+ dim3 grid_count =
+ dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
+ dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ for (int n = 0; n < N; n++) {
+ uint32_t offset = n * H * W;
+
+ cc2d::init_labeling<<<grid, block, 0, stream>>>(
+ labels.data_ptr<int32_t>() + offset, W, H);
+ cc2d::merge<<<grid, block, 0, stream>>>(
+ inputs.data_ptr<uint8_t>() + offset,
+ labels.data_ptr<int32_t>() + offset,
+ W,
+ H);
+ cc2d::compression<<<grid, block, 0, stream>>>(
+ labels.data_ptr<int32_t>() + offset, W, H);
+ cc2d::final_labeling<<<grid, block, 0, stream>>>(
+ inputs.data_ptr<uint8_t>() + offset,
+ labels.data_ptr<int32_t>() + offset,
+ W,
+ H);
+
+ // get the counting of each pixel
+ cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
+ labels.data_ptr<int32_t>() + offset,
+ counts_init.data_ptr<int32_t>() + offset,
+ W,
+ H);
+ cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
+ labels.data_ptr<int32_t>() + offset,
+ counts_init.data_ptr<int32_t>() + offset,
+ counts_final.data_ptr<int32_t>() + offset,
+ W,
+ H);
+ }
+
+ // returned values are [labels, counts]
+ std::vector<torch::Tensor> outputs;
+ outputs.push_back(labels);
+ outputs.push_back(counts_final);
+ return outputs;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def(
+ "get_connected_componnets",
+ &get_connected_componnets,
+ "get_connected_componnets");
+}
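The CUDA file above implements two-pass, block-based union-find connected-component labeling over 2x2 pixel blocks and exposes it to Python as get_connected_componnets (spelling as in the upstream source). It expects a uint8 CUDA tensor of shape [N, 1, H, W] with even H and W, and returns per-pixel component labels plus per-pixel component sizes, both int32. One way to build and call it is the PyTorch JIT extension loader; the module name and source path below are illustrative, and the packaged build may use a setup script instead:

import torch
from torch.utils.cpp_extension import load

# Compiles the .cu file on the fly into an importable module (requires a CUDA toolchain).
cc = load(name="sam2_cc", sources=["libs/sam2/csrc/connected_components.cu"], verbose=True)

mask = (torch.rand(1, 1, 256, 256, device="cuda") > 0.5).to(torch.uint8)
labels, sizes = cc.get_connected_componnets(mask)  # both int32, same shape as mask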
diff --git a/libs/sam2/modeling/__init__.py b/libs/sam2/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/libs/sam2/modeling/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/libs/sam2/modeling/backbones/__init__.py b/libs/sam2/modeling/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/libs/sam2/modeling/backbones/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/libs/sam2/modeling/backbones/hieradet.py b/libs/sam2/modeling/backbones/hieradet.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ac77b61d8e1345a301686d39ef2ab6e4b035fb
--- /dev/null
+++ b/libs/sam2/modeling/backbones/hieradet.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from functools import partial
+from typing import List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from iopath.common.file_io import g_pathmgr
+
+from sam2.modeling.backbones.utils import (
+ PatchEmbed,
+ window_partition,
+ window_unpartition,
+)
+
+from sam2.modeling.sam2_utils import DropPath, MLP
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+ if pool is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = pool(x)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ if norm:
+ x = norm(x)
+
+ return x
+
+
+class MultiScaleAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ q_pool: nn.Module = None,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.num_heads = num_heads
+ self.q_pool = q_pool
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ q, k, v = torch.unbind(qkv, 2)
+
+ # Q pooling (for downsample at stage changes)
+ if self.q_pool:
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+ H, W = q.shape[1:3] # downsampled shape
+ q = q.reshape(B, H * W, self.num_heads, -1)
+
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+ x = F.scaled_dot_product_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ )
+ # Transpose back
+ x = x.transpose(1, 2)
+ x = x.reshape(B, H, W, -1)
+
+ x = self.proj(x)
+
+ return x
+
+
+class MultiScaleBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ drop_path: float = 0.0,
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
+ q_stride: Tuple[int, int] = None,
+ act_layer: nn.Module = nn.GELU,
+ window_size: int = 0,
+ ):
+ super().__init__()
+
+ if isinstance(norm_layer, str):
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.norm1 = norm_layer(dim)
+
+ self.window_size = window_size
+
+ self.pool, self.q_stride = None, q_stride
+ if self.q_stride:
+ self.pool = nn.MaxPool2d(
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
+ )
+
+ self.attn = MultiScaleAttention(
+ dim,
+ dim_out,
+ num_heads=num_heads,
+ q_pool=self.pool,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim_out)
+ self.mlp = MLP(
+ dim_out,
+ int(dim_out * mlp_ratio),
+ dim_out,
+ num_layers=2,
+ activation=act_layer,
+ )
+
+ if dim != dim_out:
+ self.proj = nn.Linear(dim, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x # B, H, W, C
+ x = self.norm1(x)
+
+ # Skip connection
+ if self.dim != self.dim_out:
+ shortcut = do_pool(self.proj(x), self.pool)
+
+ # Window partition
+ window_size = self.window_size
+ if window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, window_size)
+
+ # Window Attention + Q Pooling (if stage change)
+ x = self.attn(x)
+ if self.q_stride:
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.q_stride[0]
+ H, W = shortcut.shape[1:3]
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+ x = shortcut + self.drop_path(x)
+ # MLP
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Hiera(nn.Module):
+ """
+ Reference: https://arxiv.org/abs/2306.00989
+ """
+
+ def __init__(
+ self,
+ embed_dim: int = 96, # initial embed dim
+ num_heads: int = 1, # initial number of heads
+ drop_path_rate: float = 0.0, # stochastic depth
+ q_pool: int = 3, # number of q_pool stages
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
+ head_mul: float = 2.0, # head_mul factor at stage shift
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+ # window size per stage, when not using global att.
+ window_spec: Tuple[int, ...] = (
+ 8,
+ 4,
+ 14,
+ 7,
+ ),
+ # global attn in these blocks
+ global_att_blocks: Tuple[int, ...] = (
+ 12,
+ 16,
+ 20,
+ ),
+ weights_path=None,
+ return_interm_layers=True, # return feats from every stage
+ ):
+ super().__init__()
+
+ assert len(stages) == len(window_spec)
+ self.window_spec = window_spec
+
+ depth = sum(stages)
+ self.q_stride = q_stride
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+ self.return_interm_layers = return_interm_layers
+
+ self.patch_embed = PatchEmbed(
+ embed_dim=embed_dim,
+ )
+ # Which blocks have global att?
+ self.global_att_blocks = global_att_blocks
+
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
+ )
+ self.pos_embed_window = nn.Parameter(
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
+ )
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+
+ cur_stage = 1
+ self.blocks = nn.ModuleList()
+
+ for i in range(depth):
+ dim_out = embed_dim
+ # lags by a block, so first block of
+ # next stage uses an initial window size
+ # of previous stage and final window size of current stage
+ window_size = self.window_spec[cur_stage - 1]
+
+ if self.global_att_blocks is not None:
+ window_size = 0 if i in self.global_att_blocks else window_size
+
+ if i - 1 in self.stage_ends:
+ dim_out = int(embed_dim * dim_mul)
+ num_heads = int(num_heads * head_mul)
+ cur_stage += 1
+
+ block = MultiScaleBlock(
+ dim=embed_dim,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ drop_path=dpr[i],
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
+ window_size=window_size,
+ )
+
+ embed_dim = dim_out
+ self.blocks.append(block)
+
+ self.channel_list = (
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+ if return_interm_layers
+ else [self.blocks[-1].dim_out]
+ )
+
+ if weights_path is not None:
+ with g_pathmgr.open(weights_path, "rb") as f:
+ chkpt = torch.load(f, map_location="cpu")
+ logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
+
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+ pos_embed = pos_embed + window_embed.tile(
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ x = self.patch_embed(x)
+ # x: (B, H, W, C)
+
+ # Add pos embed
+ x = x + self._get_pos_embed(x.shape[1:3])
+
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if (i == self.stage_ends[-1]) or (
+ i in self.stage_ends and self.return_interm_layers
+ ):
+ feats = x.permute(0, 3, 1, 2)
+ outputs.append(feats)
+
+ return outputs
+
+ def get_layer_id(self, layer_name):
+ # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ num_layers = self.get_num_layers()
+
+ if layer_name.find("rel_pos") != -1:
+ return num_layers + 1
+ elif layer_name.find("pos_embed") != -1:
+ return 0
+ elif layer_name.find("patch_embed") != -1:
+ return 0
+ elif layer_name.find("blocks") != -1:
+ return int(layer_name.split("blocks")[1].split(".")[1]) + 1
+ else:
+ return num_layers + 1
+
+ def get_num_layers(self) -> int:
+ return len(self.blocks)
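Hiera, defined above, is a four-stage hierarchical ViT: PatchEmbed downsamples the input by 4, each stage transition doubles the channel width (dim_mul) and halves the spatial resolution via Q-pooling, and forward returns one feature map per stage, highest resolution first. A shape sketch for the tiny configuration used in sam2_hiera_t.yaml, assuming the module is importable as shipped:

import torch
from sam2.modeling.backbones.hieradet import Hiera

# Tiny variant (stages [1, 2, 7, 2], embed_dim 96), matching the config above.
trunk = Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2),
              global_att_blocks=(5, 7, 9), window_pos_embed_bkg_spatial_size=(7, 7))
feats = trunk(torch.randn(1, 3, 1024, 1024))
print([tuple(f.shape) for f in feats])
# [(1, 96, 256, 256), (1, 192, 128, 128), (1, 384, 64, 64), (1, 768, 32, 32)]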
diff --git a/libs/sam2/modeling/backbones/image_encoder.py b/libs/sam2/modeling/backbones/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e9266bc98596e97ca303118c910ed24f6cee2c
--- /dev/null
+++ b/libs/sam2/modeling/backbones/image_encoder.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ImageEncoder(nn.Module):
+ def __init__(
+ self,
+ trunk: nn.Module,
+ neck: nn.Module,
+ scalp: int = 0,
+ ):
+ super().__init__()
+ self.trunk = trunk
+ self.neck = neck
+ self.scalp = scalp
+ assert (
+ self.trunk.channel_list == self.neck.backbone_channel_list
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
+
+ def forward(self, sample: torch.Tensor):
+ # Forward through backbone
+ features, pos = self.neck(self.trunk(sample))
+ if self.scalp > 0:
+ # Discard the lowest resolution features
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+ src = features[-1]
+ output = {
+ "vision_features": src,
+ "vision_pos_enc": pos,
+ "backbone_fpn": features,
+ }
+ return output
+
+
+class FpnNeck(nn.Module):
+ """
+ A modified variant of Feature Pyramid Network (FPN) neck
+ (we remove output conv and also do bicubic interpolation similar to ViT
+ pos embed interpolation)
+ """
+
+ def __init__(
+ self,
+ position_encoding: nn.Module,
+ d_model: int,
+ backbone_channel_list: List[int],
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ fpn_interp_model: str = "bilinear",
+ fuse_type: str = "sum",
+ fpn_top_down_levels: Optional[List[int]] = None,
+ ):
+ """Initialize the neck
+ :param trunk: the backbone
+ :param position_encoding: the positional encoding to use
+ :param d_model: the dimension of the model
+ :param neck_norm: the normalization to use
+ """
+ super().__init__()
+ self.position_encoding = position_encoding
+ self.convs = nn.ModuleList()
+ self.backbone_channel_list = backbone_channel_list
+ self.d_model = d_model
+ for dim in backbone_channel_list:
+ current = nn.Sequential()
+ current.add_module(
+ "conv",
+ nn.Conv2d(
+ in_channels=dim,
+ out_channels=d_model,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ )
+
+ self.convs.append(current)
+ self.fpn_interp_model = fpn_interp_model
+ assert fuse_type in ["sum", "avg"]
+ self.fuse_type = fuse_type
+
+ # levels to have top-down features in its outputs
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+ # have top-down propagation, while outputs of level 0 and level 1 have only
+ # lateral features from the same backbone level.
+ if fpn_top_down_levels is None:
+ # default is to have top-down features on all levels
+ fpn_top_down_levels = range(len(self.convs))
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+ def forward(self, xs: List[torch.Tensor]):
+
+ out = [None] * len(self.convs)
+ pos = [None] * len(self.convs)
+ assert len(xs) == len(self.convs)
+ # fpn forward pass
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+ prev_features = None
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ x = xs[i]
+ lateral_features = self.convs[n - i](x)
+ if i in self.fpn_top_down_levels and prev_features is not None:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode=self.fpn_interp_model,
+ align_corners=(
+ None if self.fpn_interp_model == "nearest" else False
+ ),
+ antialias=False,
+ )
+ prev_features = lateral_features + top_down_features
+ if self.fuse_type == "avg":
+ prev_features /= 2
+ else:
+ prev_features = lateral_features
+ x_out = prev_features
+ out[i] = x_out
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+ return out, pos
diff --git a/libs/sam2/modeling/backbones/utils.py b/libs/sam2/modeling/backbones/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7
--- /dev/null
+++ b/libs/sam2/modeling/backbones/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some utilities for backbones, in particular for windowing"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def window_partition(x, window_size):
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, ...] = (7, 7),
+ stride: Tuple[int, ...] = (4, 4),
+ padding: Tuple[int, ...] = (3, 3),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ):
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
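The windowing helpers above are what MultiScaleBlock uses for local attention: window_partition pads H and W up to multiples of the window size and tiles the (B, H, W, C) feature map into window_size x window_size patches, while window_unpartition inverts the tiling and crops the padding. A quick shape check, assuming the module is importable as shipped:

import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 56, 56, 96)                        # (B, H, W, C) tokens
windows, pad_hw = window_partition(x, 8)              # (2*7*7, 8, 8, 96); pad_hw == (56, 56), no padding needed
y = window_unpartition(windows, 8, pad_hw, (56, 56))
assert y.shape == x.shape and torch.equal(y, x)       # exact round trip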
diff --git a/libs/sam2/modeling/memory_attention.py b/libs/sam2/modeling/memory_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b07f9d87e3d8194ca5e11fc20f01604d591a59d
--- /dev/null
+++ b/libs/sam2/modeling/memory_attention.py
@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from sam2.modeling.sam.transformer import RoPEAttention
+
+from sam2.modeling.sam2_utils import get_activation_fn, get_clones
+
+
+class MemoryAttentionLayer(nn.Module):
+
+ def __init__(
+ self,
+ activation: str,
+ cross_attention: nn.Module,
+ d_model: int,
+ dim_feedforward: int,
+ dropout: float,
+ pos_enc_at_attn: bool,
+ pos_enc_at_cross_attn_keys: bool,
+ pos_enc_at_cross_attn_queries: bool,
+ self_attention: nn.Module,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.dim_feedforward = dim_feedforward
+ self.dropout_value = dropout
+ self.self_attn = self_attention
+ self.cross_attn_image = cross_attention
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation_str = activation
+ self.activation = get_activation_fn(activation)
+
+ # Where to add pos enc
+ self.pos_enc_at_attn = pos_enc_at_attn
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+ def _forward_sa(self, tgt, query_pos):
+ # Self-Attention
+ tgt2 = self.norm1(tgt)
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+ tgt2 = self.self_attn(q, k, v=tgt2)
+ tgt = tgt + self.dropout1(tgt2)
+ return tgt
+
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+ kwds = {}
+ if num_k_exclude_rope > 0:
+ assert isinstance(self.cross_attn_image, RoPEAttention)
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+ # Cross-Attention
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.cross_attn_image(
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+ v=memory,
+ **kwds,
+ )
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ num_k_exclude_rope: int = 0,
+ ) -> torch.Tensor:
+
+ # Self-Attn, Cross-Attn
+ tgt = self._forward_sa(tgt, query_pos)
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+ # MLP
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+
+class MemoryAttention(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ pos_enc_at_input: bool,
+ layer: nn.Module,
+ num_layers: int,
+ batch_first: bool = True, # Do layers expect batch first input?
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.layers = get_clones(layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = nn.LayerNorm(d_model)
+ self.pos_enc_at_input = pos_enc_at_input
+ self.batch_first = batch_first
+
+ def forward(
+ self,
+ curr: torch.Tensor, # self-attention inputs
+ memory: torch.Tensor, # cross-attention inputs
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
+ ):
+ if isinstance(curr, list):
+ assert isinstance(curr_pos, list)
+ assert len(curr) == len(curr_pos) == 1
+ curr, curr_pos = (
+ curr[0],
+ curr_pos[0],
+ )
+
+ assert (
+ curr.shape[1] == memory.shape[1]
+ ), "Batch size must be the same for curr and memory"
+
+ output = curr
+ if self.pos_enc_at_input and curr_pos is not None:
+ output = output + 0.1 * curr_pos
+
+ if self.batch_first:
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+ memory = memory.transpose(0, 1)
+ memory_pos = memory_pos.transpose(0, 1)
+
+ for layer in self.layers:
+ kwds = {}
+ if isinstance(layer.cross_attn_image, RoPEAttention):
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+ output = layer(
+ tgt=output,
+ memory=memory,
+ pos=memory_pos,
+ query_pos=curr_pos,
+ **kwds,
+ )
+ normed_output = self.norm(output)
+
+ if self.batch_first:
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+
+ return normed_output
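+
+ # Note on conventions (as implied by the assert and transposes above): `curr` and `memory`
+ # are sequence-first, i.e. [L, B, d_model]; with batch_first=True they are transposed to
+ # [B, L, d_model] for the attention layers and the output is transposed back before returning.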
diff --git a/libs/sam2/modeling/memory_encoder.py b/libs/sam2/modeling/memory_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60202dfaba87232c3870fb2101b5322a119d985
--- /dev/null
+++ b/libs/sam2/modeling/memory_encoder.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
+
+
+class MaskDownSampler(nn.Module):
+ """
+ Progressively downsample a mask by total_stride, each time by stride.
+ Note that LayerNorm is applied per *token*, like in ViT.
+
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+ In the end, we linearly project to embed_dim channels.
+ """
+
+ def __init__(
+ self,
+ embed_dim=256,
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ total_stride=16,
+ activation=nn.GELU,
+ ):
+ super().__init__()
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
+ assert stride**num_layers == total_stride
+ self.encoder = nn.Sequential()
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (stride**2)
+ self.encoder.append(
+ nn.Conv2d(
+ mask_in_chans,
+ mask_out_chans,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+ )
+ self.encoder.append(LayerNorm2d(mask_out_chans))
+ self.encoder.append(activation())
+ mask_in_chans = mask_out_chans
+
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+ def forward(self, x):
+ return self.encoder(x)
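+
+ # Shape sketch with the defaults (kernel_size=stride=4, total_stride=16, embed_dim=256):
+ # two conv stages map [B, 1, H, W] -> [B, 16, H/4, W/4] -> [B, 256, H/16, W/16],
+ # followed by a final 1x1 conv projecting to embed_dim channels (256 -> 256 here).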
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class CXBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+
+ def __init__(
+ self,
+ dim,
+ kernel_size=7,
+ padding=3,
+ drop_path=0.0,
+ layer_scale_init_value=1e-6,
+ use_dwconv=True,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim if use_dwconv else 1,
+ ) # depthwise conv
+ self.norm = LayerNorm2d(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, 4 * dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = self.norm(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
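+
+ # CXBlock is shape-preserving: the depthwise conv (kernel_size=7, padding=3) keeps H and W,
+ # and the pointwise linears go dim -> 4*dim -> dim, so the output shape equals the input [N, C, H, W].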
+
+
+class Fuser(nn.Module):
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
+ super().__init__()
+ self.proj = nn.Identity()
+ self.layers = get_clones(layer, num_layers)
+
+ if input_projection:
+ assert dim is not None
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+ def forward(self, x):
+ # normally x: (N, C, H, W)
+ x = self.proj(x)
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class MemoryEncoder(nn.Module):
+ def __init__(
+ self,
+ out_dim,
+ mask_downsampler,
+ fuser,
+ position_encoding,
+ in_dim=256, # in_dim of pix_feats
+ ):
+ super().__init__()
+
+ self.mask_downsampler = mask_downsampler
+
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+ self.fuser = fuser
+ self.position_encoding = position_encoding
+ self.out_proj = nn.Identity()
+ if out_dim != in_dim:
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+ def forward(
+ self,
+ pix_feat: torch.Tensor,
+ masks: torch.Tensor,
+ skip_mask_sigmoid: bool = False,
+ ) -> dict:
+ ## Process masks
+ # apply sigmoid so that there is less domain shift from GT masks, which are boolean
+ if not skip_mask_sigmoid:
+ masks = F.sigmoid(masks)
+ masks = self.mask_downsampler(masks)
+
+ ## Fuse pix_feats and downsampled masks
+ # in case the visual features are on a different device (e.g. CPU), move them to the masks' device
+ pix_feat = pix_feat.to(masks.device)
+
+ x = self.pix_feat_proj(pix_feat)
+ x = x + masks
+ x = self.fuser(x)
+ x = self.out_proj(x)
+
+ pos = self.position_encoding(x).to(x.dtype)
+
+ return {"vision_features": x, "vision_pos_enc": [pos]}
diff --git a/libs/sam2/modeling/position_encoding.py b/libs/sam2/modeling/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..52ac22674d5d4fdd9e83b6bdf034bff56d04bc0d
--- /dev/null
+++ b/libs/sam2/modeling/position_encoding.py
@@ -0,0 +1,221 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, Optional, Tuple
+
+import numpy as np
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention Is All You Need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self,
+ num_pos_feats,
+ temperature: int = 10000,
+ normalize: bool = True,
+ scale: Optional[float] = None,
+ ):
+ super().__init__()
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
+ self.num_pos_feats = num_pos_feats // 2
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ self.cache = {}
+
+ def _encode_xy(self, x, y):
+ # The positions are expected to be normalized
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
+ x_embed = x * self.scale
+ y_embed = y * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, None] / dim_t
+ pos_y = y_embed[:, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
+ ).flatten(1)
+ pos_y = torch.stack(
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
+ ).flatten(1)
+ return pos_x, pos_y
+
+ @torch.no_grad()
+ def encode_boxes(self, x, y, w, h):
+ pos_x, pos_y = self._encode_xy(x, y)
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+ return pos
+
+ encode = encode_boxes # Backwards compatibility
+
+ @torch.no_grad()
+ def encode_points(self, x, y, labels):
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+ assert bx == by and nx == ny and bx == bl and nx == nl
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+ return pos
+
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor):
+ cache_key = (x.shape[-2], x.shape[-1])
+ if cache_key in self.cache:
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+ y_embed = (
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+ .view(1, -1, 1)
+ .repeat(x.shape[0], 1, x.shape[-1])
+ )
+ x_embed = (
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+ .view(1, 1, -1)
+ .repeat(x.shape[0], x.shape[-2], 1)
+ )
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ self.cache[cache_key] = pos[0]
+ return pos
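+
+ # Output sketch: forward(x) with x of shape [B, C, H, W] returns sine/cosine encodings of
+ # shape [B, num_pos_feats, H, W], where num_pos_feats is the constructor argument
+ # (half of the channels encode y, the other half encode x); results are cached per (H, W).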
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix",
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
+
+
+# Rotary Positional Encoding, adapted from:
+# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# 2. https://github.com/naver-ai/rope-vit
+# 3. https://github.com/lucidrains/rotary-embedding-torch
+
+
+def init_t_xy(end_x: int, end_y: int):
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
+ t_x = (t % end_x).float()
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
+ return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ t_x, t_y = init_t_xy(end_x, end_y)
+ freqs_x = torch.outer(t_x, freqs_x)
+ freqs_y = torch.outer(t_y, freqs_y)
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ repeat_freqs_k: bool = False,
+):
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = (
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ if xk.shape[-2] != 0
+ else None
+ )
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ if xk_ is None:
+ # no keys to rotate, due to dropout
+ return xq_out.type_as(xq).to(xq.device), xk
+ # repeat freqs along seq_len dim to match k seq_len
+ if repeat_freqs_k:
+ r = xk_.shape[-2] // xq_.shape[-2]
+ if freqs_cis.is_cuda:
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+ else:
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
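+
+ # RoPE usage sketch (illustrative; assumes head_dim divisible by 4 and seq_len == H * W):
+ #   freqs = compute_axial_cis(dim=head_dim, end_x=W, end_y=H)  # complex, [H*W, head_dim // 2]
+ #   q, k = apply_rotary_enc(q, k, freqs_cis=freqs)             # q, k: [B, n_heads, H*W, head_dim]
+ # apply_rotary_enc rotates q (and k) in the complex plane by the axial frequencies; with
+ # repeat_freqs_k=True the frequencies are tiled along the key sequence (used for memory keys).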
diff --git a/libs/sam2/modeling/sam/__init__.py b/libs/sam2/modeling/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/libs/sam2/modeling/sam/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/libs/sam2/modeling/sam/mask_decoder.py b/libs/sam2/modeling/sam/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bebc0366b2703ffcb80a44bfd19cce8339b4fed
--- /dev/null
+++ b/libs/sam2/modeling/sam/mask_decoder.py
@@ -0,0 +1,295 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.sam2_utils import LayerNorm2d, MLP
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ use_high_res_features: bool = False,
+ iou_prediction_use_sigmoid=False,
+ dynamic_multimask_via_stability=False,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ pred_obj_scores: bool = False,
+ pred_obj_scores_mlp: bool = False,
+ use_multimask_token_for_obj_ptr: bool = False,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.pred_obj_scores = pred_obj_scores
+ if self.pred_obj_scores:
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+ ),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+ ),
+ activation(),
+ )
+ self.use_high_res_features = use_high_res_features
+ if use_high_res_features:
+ self.conv_s0 = nn.Conv2d(
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
+ )
+ self.conv_s1 = nn.Conv2d(
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
+ )
+
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim,
+ iou_head_hidden_dim,
+ self.num_mask_tokens,
+ iou_head_depth,
+ sigmoid_output=iou_prediction_use_sigmoid,
+ )
+ if self.pred_obj_scores:
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+ if pred_obj_scores_mlp:
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+ # When outputting a single mask, optionally we can dynamically fall back to the best
+ # multimask output token if the single mask output token gives low stability scores.
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ torch.Tensor: batched SAM token for mask output
+ torch.Tensor: batched object score logits
+ """
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ repeat_image=repeat_image,
+ high_res_features=high_res_features,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ masks = masks[:, 1:, :, :]
+ iou_pred = iou_pred[:, 1:]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ masks = masks[:, 0:1, :, :]
+ iou_pred = iou_pred[:, 0:1]
+
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
+ else:
+ # Take the mask output token. Here we *always* use the token for single mask output.
+ # At test time, even if we track after 1-click (and using multimask_output=True),
+ # we still take the single mask token here. The rationale is that we always track
+ # after multiple clicks during training, so the past tokens seen during training
+ # are always the single mask token (and we'll let it be the object-memory token).
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
+
+ # Prepare output
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ s = 0
+ if self.pred_obj_scores:
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ s = 1
+ else:
+ output_tokens = torch.cat(
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
+ )
+ output_tokens = output_tokens.unsqueeze(0).expand(
+ sparse_prompt_embeddings.size(0), -1, -1
+ )
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ if repeat_image:
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ else:
+ assert image_embeddings.shape[0] == tokens.shape[0]
+ src = image_embeddings
+ src = src + dense_prompt_embeddings
+ assert (
+ image_pe.size(0) == 1
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, s, :]
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ if not self.use_high_res_features:
+ upscaled_embedding = self.output_upscaling(src)
+ else:
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
+ feat_s0, feat_s1 = high_res_features
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+ )
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ if self.pred_obj_scores:
+ assert s == 1
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+ else:
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+ return masks, iou_pred, mask_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds.
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
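+
+ # In other words: stability = |{logits > +delta}| / |{logits > -delta}|, i.e. the IoU between
+ # the masks thresholded at +delta and at -delta (the former is a subset of the latter).
+ # For example, 900 pixels above +delta out of 1000 above -delta gives a score of 0.9.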
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+ batch_inds = torch.arange(
+ multimask_iou_scores.size(0), device=all_iou_scores.device
+ )
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
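+
+ # Shape sketch for MaskDecoder.forward (illustrative; assumes transformer_dim=256 and
+ # 64x64 image embeddings):
+ #   image_embeddings [B, 256, 64, 64], image_pe [1, 256, 64, 64],
+ #   sparse_prompt_embeddings [B, N, 256], dense_prompt_embeddings [B, 256, 64, 64]
+ #   -> masks [B, 3, 256, 256] and iou_pred [B, 3] when multimask_output=True
+ #   (plus the SAM output tokens and object score logits); masks come out at 4x the
+ #   embedding resolution via the two transposed convolutions in output_upscaling.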
diff --git a/libs/sam2/modeling/sam/prompt_encoder.py b/libs/sam2/modeling/sam/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b3bbb95be0aea9c88f49f586ac959a9fda1b18b
--- /dev/null
+++ b/libs/sam2/modeling/sam/prompt_encoder.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.position_encoding import PositionEmbeddingRandom
+
+from sam2.modeling.sam2_utils import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+ input_image_size (tuple(int, int)): The padded size of the image as input
+ to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
+ ]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (
+ 4 * image_embedding_size[0],
+ 4 * image_embedding_size[1],
+ )
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(
+ points, self.input_image_size
+ )
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(
+ coords, self.input_image_size
+ )
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """
+ Gets the batch size of the output given the batch size of the input prompts.
+ """
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty(
+ (bs, 0, self.embed_dim), device=self._get_device()
+ )
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
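+
+ # Shape sketch for PromptEncoder.forward (illustrative):
+ #   with P point prompts and no box, sparse_embeddings is [B, P + 1, embed_dim] (a padding
+ #   point is appended when boxes are absent); a given box contributes its 2 corner tokens
+ #   and then no padding point is added.
+ #   dense_embeddings is [B, embed_dim, embed_H, embed_W]: either the downscaled input mask
+ #   or the learned no_mask_embed broadcast over the embedding grid.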
diff --git a/libs/sam2/modeling/sam/transformer.py b/libs/sam2/modeling/sam/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b6fa2f87e85a7f222fb2ba0b661734dc57a08a
--- /dev/null
+++ b/libs/sam2/modeling/sam/transformer.py
@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import math
+import warnings
+from functools import partial
+from typing import Tuple, Type
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
+from sam2.modeling.sam2_utils import MLP
+from sam2.utils.misc import get_sdpa_settings
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+# Check whether Flash Attention is available (and use it by default)
+OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+# A fallback setting to allow all available kernels if Flash Attention fails
+ALLOW_ALL_KERNELS = False
+
+
+def sdp_kernel_context(dropout_p):
+ """
+ Get the context for the attention scaled dot-product kernel. We use Flash Attention
+ by default, but fall back to all available kernels if Flash Attention fails.
+ """
+ if ALLOW_ALL_KERNELS:
+ return contextlib.nullcontext()
+
+ return torch.backends.cuda.sdp_kernel(
+ enable_flash=USE_FLASH_ATTN,
+ # if Flash attention kernel is off, then math kernel needs to be enabled
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+ enable_mem_efficient=OLD_GPU,
+ )
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLP(
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
+ )
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ dropout: float = 0.0,
+ kv_in_dim: int = None,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert (
+ self.internal_dim % num_heads == 0
+ ), "num_heads must divide embedding_dim."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ self.dropout_p = dropout
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ dropout_p = self.dropout_p if self.training else 0.0
+ # Attention
+ try:
+ with sdp_kernel_context(dropout_p):
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+ except Exception as e:
+ # Fall back to all kernels if the Flash attention kernel fails
+ warnings.warn(
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ global ALLOW_ALL_KERNELS
+ ALLOW_ALL_KERNELS = True
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
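+
+ # Head-math sketch: with embedding_dim=256, downsample_rate=2 and num_heads=8, q/k/v are
+ # projected to internal_dim=128 (16 channels per head) and out_proj maps back to 256.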
+
+
+class RoPEAttention(Attention):
+ """Attention with rotary position encoding."""
+
+ def __init__(
+ self,
+ *args,
+ rope_theta=10000.0,
+ # whether to repeat q rope to match k length
+ # this is needed for cross-attention to memories
+ rope_k_repeat=False,
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.compute_cis = partial(
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
+ )
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+ self.freqs_cis = freqs_cis
+ self.rope_k_repeat = rope_k_repeat
+
+ def forward(
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+ ) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Apply rotary position encoding
+ w = h = math.sqrt(q.shape[-2])
+ self.freqs_cis = self.freqs_cis.to(q.device)
+ if self.freqs_cis.shape[0] != q.shape[-2]:
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+ if q.shape[-2] != k.shape[-2]:
+ assert self.rope_k_repeat
+
+ num_k_rope = k.size(-2) - num_k_exclude_rope
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
+ q,
+ k[:, :, :num_k_rope],
+ freqs_cis=self.freqs_cis,
+ repeat_freqs_k=self.rope_k_repeat,
+ )
+
+ dropout_p = self.dropout_p if self.training else 0.0
+ # Attention
+ try:
+ with sdp_kernel_context(dropout_p):
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+ except Exception as e:
+ # Fall back to all kernels if the Flash attention kernel fails
+ warnings.warn(
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ global ALLOW_ALL_KERNELS
+ ALLOW_ALL_KERNELS = True
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
diff --git a/libs/sam2/modeling/sam2_base.py b/libs/sam2/modeling/sam2_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5d243adc9d7071f254dee115f92ff03d3b6e871
--- /dev/null
+++ b/libs/sam2/modeling/sam2_base.py
@@ -0,0 +1,907 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+
+from torch.nn.init import trunc_normal_
+
+from sam2.modeling.sam.mask_decoder import MaskDecoder
+from sam2.modeling.sam.prompt_encoder import PromptEncoder
+from sam2.modeling.sam.transformer import TwoWayTransformer
+from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+class SAM2Base(torch.nn.Module):
+ def __init__(
+ self,
+ image_encoder,
+ memory_attention,
+ memory_encoder,
+ num_maskmem=7, # default 1 input frame + 6 previous frames
+ image_size=512,
+ backbone_stride=16, # stride of the image backbone output
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
+ binarize_mask_from_pts_for_mem_enc=False,
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
+ max_cond_frames_in_attn=-1,
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
+ # (instead of using the transformer encoder)
+ directly_add_no_mem_embed=False,
+ # whether to use high-resolution feature maps in the SAM mask decoder
+ use_high_res_features_in_sam=False,
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
+ multimask_output_in_sam=False,
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
+ multimask_min_pt_num=1,
+ multimask_max_pt_num=1,
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
+ multimask_output_for_tracking=False,
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
+ use_multimask_token_for_obj_ptr: bool = False,
+ # whether to use sigmoid to restrict ious prediction to [0-1]
+ iou_prediction_use_sigmoid=False,
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
+ # (self.num_maskmem - 2) nearest frames taken from every r-th frame, plus the last frame.
+ memory_temporal_stride_for_eval=1,
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
+ non_overlap_masks_for_mem_enc=False,
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder=False,
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
+ max_obj_ptrs_in_encoder=16,
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
+ add_tpos_enc_to_obj_ptrs=True,
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
+ proj_tpos_enc_in_obj_ptrs=False,
+ # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
+ # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
+ use_signed_tpos_enc_to_obj_ptrs=False,
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
+ only_obj_ptrs_in_the_past_for_eval=False,
+ # Whether to predict if there is an object in the frame
+ pred_obj_scores: bool = False,
+ # Whether to use an MLP to predict object scores
+ pred_obj_scores_mlp: bool = False,
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
+ # Whether to have a fixed no obj pointer when there is no object present
+ # or to use it as an additive embedding with obj_ptr produced by decoder
+ fixed_no_obj_ptr: bool = False,
+ # Soft no object, i.e. mix in no_obj_ptr softly,
+ # hope to make recovery easier if there is a mistake and mitigate accumulation of errors
+ soft_no_obj_ptr: bool = False,
+ use_mlp_for_obj_ptr_proj: bool = False,
+ # add no obj embedding to spatial frames
+ no_obj_embed_spatial: bool = False,
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
+ sam_mask_decoder_extra_args=None,
+ compile_image_encoder: bool = False,
+ ):
+ super().__init__()
+
+ # Part 1: the image backbone
+ self.image_encoder = image_encoder
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+ if use_obj_ptrs_in_encoder:
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+ if proj_tpos_enc_in_obj_ptrs:
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+ self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+
+ # Part 2: memory attention to condition current frame's visual features
+ # with memories (and obj ptrs) from past frames
+ self.memory_attention = memory_attention
+ self.hidden_dim = image_encoder.neck.d_model
+
+ # Part 3: memory encoder for the previous frame's outputs
+ self.memory_encoder = memory_encoder
+ self.mem_dim = self.hidden_dim
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(
+ self.memory_encoder.out_proj, "weight"
+ ):
+ # if there is compression of memories along channel dim
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+ self.num_maskmem = num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.maskmem_tpos_enc = torch.nn.Parameter(
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
+ )
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+ # a single token to indicate no memory embedding from previous frames
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ trunc_normal_(self.no_mem_embed, std=0.02)
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
+ # Apply sigmoid to the output raw mask logits (to turn them from
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+ # On frames with mask input, whether to directly output the input mask without
+ # using a SAM prompt encoder + mask decoder
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+ # and SAM-style mask decoder for the final mask output
+ self.image_size = image_size
+ self.backbone_stride = backbone_stride
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+ self.pred_obj_scores = pred_obj_scores
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
+ self.soft_no_obj_ptr = soft_no_obj_ptr
+ if self.fixed_no_obj_ptr:
+ assert self.pred_obj_scores
+ assert self.use_obj_ptrs_in_encoder
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ trunc_normal_(self.no_obj_ptr, std=0.02)
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+ self.no_obj_embed_spatial = None
+ if no_obj_embed_spatial:
+ self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+ trunc_normal_(self.no_obj_embed_spatial, std=0.02)
+
+ self._build_sam_heads()
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
+
+ # Model compilation
+ if compile_image_encoder:
+ # Compile the forward function (not the full module) to allow loading checkpoints.
+ print(
+ "Image encoder compilation is enabled. First forward pass will be slow."
+ )
+ self.image_encoder.forward = torch.compile(
+ self.image_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
+ "See notebooks/video_predictor_example.ipynb for an inference example."
+ )
+
+ def _build_sam_heads(self):
+ """Build SAM-style prompt encoder and mask decoder."""
+ self.sam_prompt_embed_dim = self.hidden_dim
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
+
+ # build PromptEncoder and MaskDecoder from SAM
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+ self.sam_prompt_encoder = PromptEncoder(
+ embed_dim=self.sam_prompt_embed_dim,
+ image_embedding_size=(
+ self.sam_image_embedding_size,
+ self.sam_image_embedding_size,
+ ),
+ input_image_size=(self.image_size, self.image_size),
+ mask_in_chans=16,
+ )
+ self.sam_mask_decoder = MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=self.sam_prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=self.sam_prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ use_high_res_features=self.use_high_res_features_in_sam,
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+ pred_obj_scores=self.pred_obj_scores,
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+ **(self.sam_mask_decoder_extra_args or {}),
+ )
+ if self.use_obj_ptrs_in_encoder:
+ # a linear projection on SAM output tokens to turn them into object pointers
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+ if self.use_mlp_for_obj_ptr_proj:
+ self.obj_ptr_proj = MLP(
+ self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
+ )
+ else:
+ self.obj_ptr_proj = torch.nn.Identity()
+ if self.proj_tpos_enc_in_obj_ptrs:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
+
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ """
+ Forward SAM prompt encoders and mask heads.
+
+ Inputs:
+ - backbone_features: image features of [B, C, H, W] shape
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
+ absolute pixel-unit coordinate in (x, y) format of the P input points
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
+ positive clicks, 0 means negative clicks, and -1 means padding
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
+ same spatial size as the image.
+        - high_res_features: either 1) None or 2) a list of length 2 containing
+          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
+          which will be used as high-resolution feature maps for the SAM decoder.
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
+ its corresponding IoU estimate.
+
+ Outputs:
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
+ the resolution (1/4 stride) of the input backbone_features.
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
+          upsampled from the low-resolution masks, with the same size as the image
+ (stride is 1 pixel).
+        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
+ if `multimask_output=False`), the estimated IoU of each output mask.
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
+ based on the output token from the SAM mask decoder.
+ """
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+            # If no points are provided, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ (
+ low_res_multimasks,
+ ious,
+ sam_output_tokens,
+ object_score_logits,
+ ) = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ if self.pred_obj_scores:
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+ """
+        Directly turn binary `mask_inputs` into output mask logits without using SAM.
+ (same input and output shapes as in _forward_sam_heads above).
+ """
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.float()
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks,
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ # a dummy IoU prediction of all 1's under mask input
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+ if not self.use_obj_ptrs_in_encoder:
+ # all zeros as a dummy object pointer (of shape [B, C])
+ obj_ptr = torch.zeros(
+ mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
+ )
+ else:
+ # produce an object pointer using the SAM decoder from the mask input
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+ backbone_features=backbone_features,
+ mask_inputs=self.mask_downsample(mask_inputs_float),
+ high_res_features=high_res_features,
+ )
+            # In this method, we treat mask_input as the output, e.g. using it directly to create spatial memory;
+            # below, we follow the same principle and use mask_input to decide whether the object appears,
+            # instead of relying on the object_scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.float()
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ if self.pred_obj_scores:
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_masks,
+ high_res_masks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def forward_image(self, img_batch: torch.Tensor):
+ """Get the image feature on the input batch."""
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
+ backbone_out["backbone_fpn"][0]
+ )
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
+ backbone_out["backbone_fpn"][1]
+ )
+ return backbone_out
+
+ def _prepare_backbone_features(self, backbone_out):
+ """Prepare and flatten visual features."""
+ backbone_out = backbone_out.copy()
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+ # flatten NxCxHxW to HWxNxC
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
+
+ def _prepare_memory_conditioned_features(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ ):
+ """Fuse the current frame's visual feature map with previous memory."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ device = current_vision_feats[-1].device
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+ # In this case, we skip the fusion with any memory.
+ if self.num_maskmem == 0: # Disable memory and skip fusion
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat
+
+ num_obj_ptr_tokens = 0
+ tpos_sign_mul = -1 if track_in_reverse else 1
+ # Step 1: condition the visual features of the current frame on previous memories
+ if not is_init_cond_frame:
+ # Retrieve the memories encoded with the maskmem backbone
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+            # Add the conditioning frames' outputs first (all cond frames have t_pos=0
+            # when getting the temporal positional embedding below)
+ assert len(output_dict["cond_frame_outputs"]) > 0
+ # Select a maximum number of temporally closest cond frames for cross attention
+ cond_outputs = output_dict["cond_frame_outputs"]
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
+ )
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
+ # We also allow taking the memory frame non-consecutively (with stride>1), in which case
+ # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
+ stride = 1 if self.training else self.memory_temporal_stride_for_eval
+ for t_pos in range(1, self.num_maskmem):
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
+ if t_rel == 1:
+ # for t_rel == 1, we take the last frame (regardless of r)
+ if not track_in_reverse:
+ # the frame immediately before this frame (i.e. frame_idx - 1)
+ prev_frame_idx = frame_idx - t_rel
+ else:
+ # the frame immediately after this frame (i.e. frame_idx + 1)
+ prev_frame_idx = frame_idx + t_rel
+ else:
+ # for t_rel >= 2, we take the memory frame from every r-th frames
+ if not track_in_reverse:
+ # first find the nearest frame among every r-th frames before this frame
+ # for r=1, this would be (frame_idx - 2)
+ prev_frame_idx = ((frame_idx - 2) // stride) * stride
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
+ else:
+ # first find the nearest frame among every r-th frames after this frame
+ # for r=1, this would be (frame_idx + 2)
+ prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+ if out is None:
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+ # frames, we still attend to it as if it's a non-conditioning frame.
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
+ t_pos_and_prevs.append((t_pos, out))
+
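+            # Illustrative example (editor-added): with forward tracking, frame_idx=10,
+            # stride=2, and num_maskmem=7, the loop above gathers non-conditioning
+            # memories from frames 0, 2, 4, 6, 8 (every stride-th frame) plus frame 9
+            # (the immediately preceding frame).
+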
+ for t_pos, prev in t_pos_and_prevs:
+ if prev is None:
+ continue # skip padding frames
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
+ feats = prev["maskmem_features"].to(device, non_blocking=True)
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+ # Temporal positional encoding
+ maskmem_enc = (
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+ )
+ to_cat_memory_pos_embed.append(maskmem_enc)
+
+ # Construct the list of past object pointers
+ if self.use_obj_ptrs_in_encoder:
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+ # First add those object pointers from selected conditioning frames
+ # (optionally, only include object pointers in the past during evaluation)
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+ ptr_cond_outputs = {
+ t: out
+ for t, out in selected_cond_outputs.items()
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+ }
+ else:
+ ptr_cond_outputs = selected_cond_outputs
+ pos_and_ptrs = [
+ # Temporal pos encoding contains how far away each pointer is from current frame
+ (
+ (
+ (frame_idx - t) * tpos_sign_mul
+ if self.use_signed_tpos_enc_to_obj_ptrs
+ else abs(frame_idx - t)
+ ),
+ out["obj_ptr"],
+ )
+ for t, out in ptr_cond_outputs.items()
+ ]
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+ if t < 0 or (num_frames is not None and t >= num_frames):
+ break
+ out = output_dict["non_cond_frame_outputs"].get(
+ t, unselected_cond_outputs.get(t, None)
+ )
+ if out is not None:
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+                # If we have at least one object pointer, add them to the cross attention
+ if len(pos_and_ptrs) > 0:
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
+ # a temporal positional embedding based on how far each object pointer is from
+ # the current frame (sine embedding normalized by the max pointer num).
+ if self.add_tpos_enc_to_obj_ptrs:
+ t_diff_max = max_obj_ptrs_in_encoder - 1
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+ obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+ else:
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+ if self.mem_dim < C:
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+ obj_ptrs = obj_ptrs.reshape(
+ -1, B, C // self.mem_dim, self.mem_dim
+ )
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+ to_cat_memory.append(obj_ptrs)
+ to_cat_memory_pos_embed.append(obj_pos)
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
+ else:
+ num_obj_ptr_tokens = 0
+ else:
+ # for initial conditioning frames, encode them without using any previous memory
+ if self.directly_add_no_mem_embed:
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+            # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+ # Step 2: Concatenate the memories and forward through the transformer encoder
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
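+        # Note (editor-added): `memory` concatenates (a) flattened spatial memory features
+        # from up to `num_maskmem` past frames and (b) object-pointer tokens; when
+        # `mem_dim < hidden_dim`, each pointer was split into hidden_dim // mem_dim tokens,
+        # and `num_obj_ptr_tokens` tells the attention layer how many trailing tokens are
+        # pointers (0 when none were added).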
+ pix_feat_with_mem = self.memory_attention(
+ curr=current_vision_feats,
+ curr_pos=current_vision_pos_embeds,
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
+ )
+ # reshape the output (HW)BC => BCHW
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """Encode the current image and its prediction into a memory feature."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
+ pred_masks_high_res
+ )
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
+ )
+ maskmem_features = maskmem_out["vision_features"]
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.no_obj_embed_spatial is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (
+ 1 - is_obj_appearing[..., None, None]
+ ) * self.no_obj_embed_spatial[..., None, None].expand(
+ *maskmem_features.shape
+ )
+
+ return maskmem_features, maskmem_pos_enc
+
+ def _track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse,
+ prev_sam_mask_logits,
+ ):
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+ sam_outputs = self._use_mask_as_output(
+ pix_feat, high_res_features, mask_inputs
+ )
+ else:
+            # fuse the visual feature with previous memory features in the memory bank
+ pix_feat = self._prepare_memory_conditioned_features(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats[-1:],
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+ feat_sizes=feat_sizes[-1:],
+ output_dict=output_dict,
+ num_frames=num_frames,
+ track_in_reverse=track_in_reverse,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ assert point_inputs is not None and mask_inputs is None
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._forward_sam_heads(
+ backbone_features=pix_feat,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ )
+
+ return current_out, sam_outputs, high_res_features, pix_feat
+
+ def _encode_memory_in_output(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ point_inputs,
+ run_mem_encoder,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ ):
+ if run_mem_encoder and self.num_maskmem > 0:
+ high_res_masks_for_mem_enc = high_res_masks
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks_for_mem_enc,
+ object_score_logits=object_score_logits,
+ is_mask_from_pts=(point_inputs is not None),
+ )
+ current_out["maskmem_features"] = maskmem_features
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ current_out["maskmem_features"] = None
+ current_out["maskmem_pos_enc"] = None
+
+ def track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
+ # in demo we might call `track_step` multiple times for each user click,
+ # and only encode the memory when the user finalizes their clicks. And in ablation
+ # settings like SAM training on static images, we don't need the memory encoder.
+ run_mem_encoder=True,
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+ prev_sam_mask_logits=None,
+ ):
+ current_out, sam_outputs, _, _ = self._track_step(
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse,
+ prev_sam_mask_logits,
+ )
+
+ (
+ _,
+ _,
+ _,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ ) = sam_outputs
+
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+ if not self.training:
+ # Only add this in inference (to avoid unused param in activation checkpointing;
+ # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
+ current_out["object_score_logits"] = object_score_logits
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (that can be used in future frames)
+ self._encode_memory_in_output(
+ current_vision_feats,
+ feat_sizes,
+ point_inputs,
+ run_mem_encoder,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ )
+
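+        # Editor-added note: the returned dict carries the low/high-res mask logits and the
+        # object pointer, plus "maskmem_features"/"maskmem_pos_enc" when the memory encoder
+        # runs; these are the fields that _prepare_memory_conditioned_features reads back
+        # from `output_dict` on later frames.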
+ return current_out
+
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
+ """Whether to use multimask output in the SAM head."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+ multimask_output = (
+ self.multimask_output_in_sam
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+ )
+ return multimask_output
+
+ def _apply_non_overlapping_constraints(self, pred_masks):
+ """
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
+ keep only the highest scoring object at each spatial location in pred_masks.
+ """
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
diff --git a/libs/sam2/modeling/sam2_utils.py b/libs/sam2/modeling/sam2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16caae3a9a49e451b2d03d1ee60c47f8e9ed23c
--- /dev/null
+++ b/libs/sam2/modeling/sam2_utils.py
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import copy
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.utils.misc import mask_to_box
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+ """
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
+ - a) the closest conditioning frame before `frame_idx` (if any);
+ - b) the closest conditioning frame after `frame_idx` (if any);
+ - c) any other temporally closest conditioning frames until reaching a total
+ of `max_cond_frame_num` conditioning frames.
+
+ Outputs:
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
+ """
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+ selected_outputs = cond_frame_outputs
+ unselected_outputs = {}
+ else:
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+ selected_outputs = {}
+
+ # the closest conditioning frame before `frame_idx` (if any)
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+ if idx_before is not None:
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+ # the closest conditioning frame after `frame_idx` (if any)
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+ if idx_after is not None:
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+ # add other temporally closest conditioning frames until reaching a total
+ # of `max_cond_frame_num` conditioning frames.
+ num_remain = max_cond_frame_num - len(selected_outputs)
+ inds_remain = sorted(
+ (t for t in cond_frame_outputs if t not in selected_outputs),
+ key=lambda x: abs(x - frame_idx),
+ )[:num_remain]
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+ unselected_outputs = {
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
+ }
+
+ return selected_outputs, unselected_outputs
+
+
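+def _demo_select_closest_cond_frames():
+    # Editor-added usage sketch (not part of upstream SAM2): keep the two conditioning
+    # frames closest to frame 10 out of frames {0, 4, 12, 30}.
+    cond_outputs = {0: "out0", 4: "out4", 12: "out12", 30: "out30"}
+    selected, unselected = select_closest_cond_frames(
+        frame_idx=10, cond_frame_outputs=cond_outputs, max_cond_frame_num=2
+    )
+    # 4 is the closest frame before 10 and 12 is the closest frame at/after 10.
+    assert sorted(selected) == [4, 12] and sorted(unselected) == [0, 30]
+    return selected, unselected
+
+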
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """
+ Get 1D sine positional embedding as in the original Transformer paper.
+ """
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
+
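+def _demo_get_1d_sine_pe():
+    # Editor-added usage sketch (not part of upstream SAM2): embed 4 temporal positions
+    # into 8-dim sine/cosine features, as done for object-pointer temporal encodings.
+    pos = torch.arange(4, dtype=torch.float32)
+    pe = get_1d_sine_pe(pos, dim=8)
+    assert pe.shape == (4, 8)  # first half sin, second half cos
+    return pe
+
+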
+def get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+
+def get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DropPath(nn.Module):
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
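+def _demo_drop_path():
+    # Editor-added usage sketch (not part of upstream SAM2): DropPath (stochastic depth)
+    # is the identity outside training; in training it zeroes whole samples with
+    # probability `drop_prob` and rescales the survivors by 1 / keep_prob.
+    dp = DropPath(drop_prob=0.5)
+    dp.eval()
+    x = torch.ones(2, 3)
+    assert torch.equal(dp(x), x)  # eval mode: no paths dropped
+    return dp
+
+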
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: nn.Module = nn.ReLU,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+ self.act = activation()
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
+
+
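+def _demo_mlp():
+    # Editor-added usage sketch (not part of upstream SAM2): a 3-layer MLP mapping
+    # 256-dim tokens to 4 output values (hidden width 256, ReLU activations).
+    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
+    out = mlp(torch.randn(2, 256))
+    assert out.shape == (2, 4)
+    return out
+
+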
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
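+def _demo_layer_norm_2d():
+    # Editor-added usage sketch (not part of upstream SAM2): LayerNorm2d normalizes each
+    # spatial location across the channel dimension of an NCHW feature map.
+    ln = LayerNorm2d(num_channels=8)
+    x = torch.randn(2, 8, 4, 4)
+    y = ln(x)
+    assert y.shape == x.shape
+    # with weight=1 and bias=0 at init, per-location channel means are ~0
+    assert torch.allclose(y.mean(dim=1), torch.zeros(2, 4, 4), atol=1e-5)
+    return y
+
+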
+def sample_box_points(
+ masks: torch.Tensor,
+ noise: float = 0.1, # SAM default
+ noise_bound: int = 20, # SAM default
+ top_left_label: int = 2,
+ bottom_right_label: int = 3,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+    Sample a noised version of the top left and bottom right corners of the bounding
+    box of each mask in `masks`.
+
+    Inputs:
+    - masks: [B, 1, H, W] binary masks, dtype=torch.Tensor
+    - noise: noise as a fraction of box width and height, dtype=float
+    - noise_bound: maximum amount of noise (in pixels), dtype=int
+
+    Returns:
+    - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
+    - box_labels: [B, num_pt], label 2 is reserved for the top left and 3 for the bottom right corner, dtype=torch.int32
+ """
+ device = masks.device
+ box_coords = mask_to_box(masks)
+ B, _, H, W = masks.shape
+ box_labels = torch.tensor(
+ [top_left_label, bottom_right_label], dtype=torch.int, device=device
+ ).repeat(B)
+ if noise > 0.0:
+ if not isinstance(noise_bound, torch.Tensor):
+ noise_bound = torch.tensor(noise_bound, device=device)
+ bbox_w = box_coords[..., 2] - box_coords[..., 0]
+ bbox_h = box_coords[..., 3] - box_coords[..., 1]
+ max_dx = torch.min(bbox_w * noise, noise_bound)
+ max_dy = torch.min(bbox_h * noise, noise_bound)
+ box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
+ box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
+
+ box_coords = box_coords + box_noise
+ img_bounds = (
+ torch.tensor([W, H, W, H], device=device) - 1
+ ) # uncentered pixel coords
+ box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
+
+ box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
+ box_labels = box_labels.reshape(-1, 2)
+ return box_coords, box_labels
+
+
+def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
+ """
+ Sample `num_pt` random points (along with their labels) independently from the error regions.
+
+ Inputs:
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
+ - num_pt: int, number of points to sample independently for each of the B error maps
+
+ Outputs:
+ - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
+ - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
+ negative clicks
+ """
+ if pred_masks is None: # if pred_masks is not provided, treat it as empty
+ pred_masks = torch.zeros_like(gt_masks)
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
+ assert num_pt >= 0
+
+ B, _, H_im, W_im = gt_masks.shape
+ device = gt_masks.device
+
+ # false positive region, a new point sampled in this region should have
+ # negative label to correct the FP error
+ fp_masks = ~gt_masks & pred_masks
+ # false negative region, a new point sampled in this region should have
+ # positive label to correct the FN error
+ fn_masks = gt_masks & ~pred_masks
+    # whether the prediction completely matches the ground-truth on each mask
+ all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
+ all_correct = all_correct[..., None, None]
+
+ # channel 0 is FP map, while channel 1 is FN map
+ pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
+ # sample a negative new click from FP region or a positive new click
+    # from FN region, depending on where the maximum falls,
+ # and in case the predictions are all correct (no FP or FN), we just
+ # sample a negative click from the background region
+ pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
+ pts_noise[..., 1] *= fn_masks
+ pts_idx = pts_noise.flatten(2).argmax(dim=2)
+ labels = (pts_idx % 2).to(torch.int32)
+ pts_idx = pts_idx // 2
+ pts_x = pts_idx % W_im
+ pts_y = pts_idx // W_im
+ points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
+ return points, labels
+
+
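+def _demo_sample_random_points_from_errors():
+    # Editor-added usage sketch (not part of upstream SAM2): sample one corrective click
+    # per mask. Here the prediction misses part of the ground truth, so the click lands
+    # in the false-negative region and gets a positive label.
+    gt = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
+    gt[:, :, 2:6, 2:6] = True
+    pred = torch.zeros_like(gt)
+    pred[:, :, 2:6, 2:4] = True  # only the left half is predicted
+    points, labels = sample_random_points_from_errors(gt, pred, num_pt=1)
+    assert points.shape == (1, 1, 2) and labels.shape == (1, 1)
+    assert labels.item() == 1  # positive click to fix the false negative
+    return points, labels
+
+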
+def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
+ """
+ Sample 1 random point (along with its label) from the center of each error region,
+ that is, the point with the largest distance to the boundary of each error region.
+ This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
+
+ Inputs:
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
+ - padding: if True, pad with boundary of 1 px for distance transform
+
+ Outputs:
+ - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
+ - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
+ """
+ import cv2
+
+ if pred_masks is None:
+ pred_masks = torch.zeros_like(gt_masks)
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
+
+ B, _, _, W_im = gt_masks.shape
+ device = gt_masks.device
+
+ # false positive region, a new point sampled in this region should have
+ # negative label to correct the FP error
+ fp_masks = ~gt_masks & pred_masks
+ # false negative region, a new point sampled in this region should have
+ # positive label to correct the FN error
+ fn_masks = gt_masks & ~pred_masks
+
+ fp_masks = fp_masks.cpu().numpy()
+ fn_masks = fn_masks.cpu().numpy()
+ points = torch.zeros(B, 1, 2, dtype=torch.float)
+ labels = torch.ones(B, 1, dtype=torch.int32)
+ for b in range(B):
+ fn_mask = fn_masks[b, 0]
+ fp_mask = fp_masks[b, 0]
+ if padding:
+ fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
+ fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
+ # compute the distance of each point in FN/FP region to its boundary
+ fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
+ fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
+ if padding:
+ fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
+ fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
+
+ # take the point in FN/FP region with the largest distance to its boundary
+ fn_mask_dt_flat = fn_mask_dt.reshape(-1)
+ fp_mask_dt_flat = fp_mask_dt.reshape(-1)
+ fn_argmax = np.argmax(fn_mask_dt_flat)
+ fp_argmax = np.argmax(fp_mask_dt_flat)
+ is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
+ pt_idx = fn_argmax if is_positive else fp_argmax
+ points[b, 0, 0] = pt_idx % W_im # x
+ points[b, 0, 1] = pt_idx // W_im # y
+ labels[b, 0] = int(is_positive)
+
+ points = points.to(device)
+ labels = labels.to(device)
+ return points, labels
+
+
+def get_next_point(gt_masks, pred_masks, method):
+ if method == "uniform":
+ return sample_random_points_from_errors(gt_masks, pred_masks)
+ elif method == "center":
+ return sample_one_point_from_error_center(gt_masks, pred_masks)
+ else:
+ raise ValueError(f"unknown sampling method {method}")
diff --git a/libs/sam2/sam2_hiera_b+.yaml b/libs/sam2/sam2_hiera_b+.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..998d9c98c9ff4e8ddd55deff72aa0d9067977418
--- /dev/null
+++ b/libs/sam2/sam2_hiera_b+.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_b+.yaml
\ No newline at end of file
diff --git a/libs/sam2/sam2_hiera_l.yaml b/libs/sam2/sam2_hiera_l.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..c0e7e58e1951d5c55a3a3ebe6b803dd814cf9d86
--- /dev/null
+++ b/libs/sam2/sam2_hiera_l.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_l.yaml
\ No newline at end of file
diff --git a/libs/sam2/sam2_hiera_s.yaml b/libs/sam2/sam2_hiera_s.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..41896a26beb2aa831d18b0bf3c349ed43deeef68
--- /dev/null
+++ b/libs/sam2/sam2_hiera_s.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_s.yaml
\ No newline at end of file
diff --git a/libs/sam2/sam2_hiera_t.yaml b/libs/sam2/sam2_hiera_t.yaml
new file mode 120000
index 0000000000000000000000000000000000000000..71ff3abbb1e11f8b82100a0a1d63cb267eefe52a
--- /dev/null
+++ b/libs/sam2/sam2_hiera_t.yaml
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_t.yaml
\ No newline at end of file
diff --git a/libs/sam2/sam2_image_predictor.py b/libs/sam2/sam2_image_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b43ef536e5489e6099e90a470bfcbb5bc0aadcc
--- /dev/null
+++ b/libs/sam2/sam2_image_predictor.py
@@ -0,0 +1,466 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL.Image import Image
+
+from sam2.modeling.sam2_base import SAM2Base
+
+from sam2.utils.transforms import SAM2Transforms
+
+
+class SAM2ImagePredictor:
+ def __init__(
+ self,
+ sam_model: SAM2Base,
+ mask_threshold=0.0,
+ max_hole_area=0.0,
+ max_sprinkle_area=0.0,
+ **kwargs,
+ ) -> None:
+ """
+ Uses SAM-2 to calculate the image embedding for an image, and then
+        allows repeated, efficient mask prediction given prompts.
+
+ Arguments:
+ sam_model (Sam-2): The model to use for mask prediction.
+ mask_threshold (float): The threshold to use when converting mask logits
+ to binary masks. Masks are thresholded at 0 by default.
+            max_hole_area (int): If max_hole_area > 0, we fill small holes of area up to
+              max_hole_area in low_res_masks.
+            max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles of area
+              up to max_sprinkle_area in low_res_masks.
+ """
+ super().__init__()
+ self.model = sam_model
+ self._transforms = SAM2Transforms(
+ resolution=self.model.image_size,
+ mask_threshold=mask_threshold,
+ max_hole_area=max_hole_area,
+ max_sprinkle_area=max_sprinkle_area,
+ )
+
+ # Predictor state
+ self._is_image_set = False
+ self._features = None
+ self._orig_hw = None
+ # Whether the predictor is set for single image or a batch of images
+ self._is_batch = False
+
+ # Predictor config
+ self.mask_threshold = mask_threshold
+
+ # Spatial dim for backbone feature maps
+ self._bb_feat_sizes = [
+ (256, 256),
+ (128, 128),
+ (64, 64),
+ ]
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, cache_dir, device) -> "SAM2ImagePredictor":
+ """
+ Load a pretrained model from the Hugging Face hub.
+
+ Arguments:
+ model_id (str): The Hugging Face repository ID.
+            cache_dir: The local directory used to cache the downloaded checkpoint.
+            device: The device to load the model onto.
+
+ Returns:
+ (SAM2ImagePredictor): The loaded model.
+ """
+ from sam2.build_sam import build_sam2_hf
+
+ sam_model = build_sam2_hf(model_id, cache_dir, device)
+ return cls(sam_model)
+
+ @torch.no_grad()
+ def set_image(
+ self,
+ image: Union[np.ndarray, Image],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
+ with pixel values in [0, 255].
+ """
+ self.reset_predictor()
+ # Transform the image to the form expected by the model
+ if isinstance(image, np.ndarray):
+ logging.info("For numpy array image, we assume (HxWxC) format")
+ self._orig_hw = [image.shape[:2]]
+ elif isinstance(image, Image):
+ w, h = image.size
+ self._orig_hw = [(h, w)]
+ else:
+ raise NotImplementedError("Image format not supported")
+
+ input_image = self._transforms(image)
+ input_image = input_image[None, ...].to(self.device)
+
+ assert (
+ len(input_image.shape) == 4 and input_image.shape[1] == 3
+ ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+ logging.info("Computing image embeddings for the provided image...")
+ backbone_out = self.model.forward_image(input_image)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+ feats = [
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+ self._is_image_set = True
+ logging.info("Image embeddings computed.")
+
+ @torch.no_grad()
+ def set_image_batch(
+ self,
+        image_list: List[np.ndarray],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image batch, allowing
+ masks to be predicted with the 'predict_batch' method.
+
+ Arguments:
+ image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
+ with pixel values in [0, 255].
+ """
+ self.reset_predictor()
+ assert isinstance(image_list, list)
+ self._orig_hw = []
+ for image in image_list:
+ assert isinstance(
+ image, np.ndarray
+ ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
+ self._orig_hw.append(image.shape[:2])
+ # Transform the image to the form expected by the model
+ img_batch = self._transforms.forward_batch(image_list)
+ img_batch = img_batch.to(self.device)
+ batch_size = img_batch.shape[0]
+ assert (
+ len(img_batch.shape) == 4 and img_batch.shape[1] == 3
+ ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
+ logging.info("Computing image embeddings for the provided images...")
+ backbone_out = self.model.forward_image(img_batch)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+ feats = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+ self._is_image_set = True
+ self._is_batch = True
+ logging.info("Image embeddings computed.")
+
+ def predict_batch(
+ self,
+ point_coords_batch: List[np.ndarray] = None,
+ point_labels_batch: List[np.ndarray] = None,
+ box_batch: List[np.ndarray] = None,
+ mask_input_batch: List[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ normalize_coords=True,
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+ """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
+ It returns a tuple of lists of masks, ious, and low_res_masks_logits.
+ """
+ assert self._is_batch, "This function should only be used when in batched mode"
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image_batch(...) before mask prediction."
+ )
+ num_images = len(self._features["image_embed"])
+ all_masks = []
+ all_ious = []
+ all_low_res_masks = []
+ for img_idx in range(num_images):
+ # Transform input prompts
+ point_coords = (
+ point_coords_batch[img_idx] if point_coords_batch is not None else None
+ )
+ point_labels = (
+ point_labels_batch[img_idx] if point_labels_batch is not None else None
+ )
+ box = box_batch[img_idx] if box_batch is not None else None
+ mask_input = (
+ mask_input_batch[img_idx] if mask_input_batch is not None else None
+ )
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+ point_coords,
+ point_labels,
+ box,
+ mask_input,
+ normalize_coords,
+ img_idx=img_idx,
+ )
+ masks, iou_predictions, low_res_masks = self._predict(
+ unnorm_coords,
+ labels,
+ unnorm_box,
+ mask_input,
+ multimask_output,
+ return_logits=return_logits,
+ img_idx=img_idx,
+ )
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+ iou_predictions_np = (
+ iou_predictions.squeeze(0).float().detach().cpu().numpy()
+ )
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+ all_masks.append(masks_np)
+ all_ious.append(iou_predictions_np)
+ all_low_res_masks.append(low_res_masks_np)
+
+ return all_masks, all_ious, all_low_res_masks
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ normalize_coords=True,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+            box (np.ndarray or None): A length 4 array giving a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+            normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1], and point_coords is expected to be given with respect to the original image dimensions.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ # Transform input prompts
+
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+ point_coords, point_labels, box, mask_input, normalize_coords
+ )
+
+ masks, iou_predictions, low_res_masks = self._predict(
+ unnorm_coords,
+ labels,
+ unnorm_box,
+ mask_input,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+ iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+ return masks_np, iou_predictions_np, low_res_masks_np
+
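+    # Editor-added usage sketch (illustrative, not part of upstream SAM2), assuming a
+    # built SAM2Base model `sam_model` and an RGB uint8 image `image` of shape (H, W, 3):
+    #
+    #     predictor = SAM2ImagePredictor(sam_model)
+    #     predictor.set_image(image)
+    #     masks, ious, low_res = predictor.predict(
+    #         point_coords=np.array([[250, 180]]),  # one (x, y) click in pixels
+    #         point_labels=np.array([1]),           # 1 = foreground click
+    #         multimask_output=True,
+    #     )
+    #     best_mask = masks[np.argmax(ious)]        # pick the highest-scoring mask
+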
+ def _prep_prompts(
+ self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
+ ):
+
+ unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = torch.as_tensor(
+ point_coords, dtype=torch.float, device=self.device
+ )
+ unnorm_coords = self._transforms.transform_coords(
+ point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+ )
+ labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ if len(unnorm_coords.shape) == 2:
+ unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
+ if box is not None:
+ box = torch.as_tensor(box, dtype=torch.float, device=self.device)
+ unnorm_box = self._transforms.transform_boxes(
+ box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+ ) # Bx2x2
+ if mask_logits is not None:
+ mask_input = torch.as_tensor(
+ mask_logits, dtype=torch.float, device=self.device
+ )
+ if len(mask_input.shape) == 3:
+ mask_input = mask_input[None, :, :, :]
+ return mask_input, unnorm_coords, labels, unnorm_box
+
+ @torch.no_grad()
+ def _predict(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ img_idx: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using SAM2Transforms.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+            boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+              model, in XYXY format.
+            mask_input (torch.Tensor): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ if point_coords is not None:
+ concat_points = (point_coords, point_labels)
+ else:
+ concat_points = None
+
+ # Embed prompts
+ if boxes is not None:
+ box_coords = boxes.reshape(-1, 2, 2)
+ box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+ box_labels = box_labels.repeat(boxes.size(0), 1)
+ # we merge "boxes" and "points" into a single "concat_points" input (where
+ # boxes are added at the beginning) to sam_prompt_encoder
+ if concat_points is not None:
+ concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+ concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+ concat_points = (concat_coords, concat_labels)
+ else:
+ concat_points = (box_coords, box_labels)
+
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+ points=concat_points,
+ boxes=None,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ batched_mode = (
+ concat_points is not None and concat_points[0].shape[0] > 1
+ ) # multi object prediction
+ high_res_features = [
+ feat_level[img_idx].unsqueeze(0)
+ for feat_level in self._features["high_res_feats"]
+ ]
+ low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+ image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=batched_mode,
+ high_res_features=high_res_features,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self._transforms.postprocess_masks(
+ low_res_masks, self._orig_hw[img_idx]
+ )
+ low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
+ if not return_logits:
+ masks = masks > self.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert (
+ self._features is not None
+ ), "Features must exist if an image has been set."
+ return self._features["image_embed"]
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_predictor(self) -> None:
+ """
+ Resets the image embeddings and other state variables.
+ """
+ self._is_image_set = False
+ self._features = None
+ self._orig_hw = None
+ self._is_batch = False
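+
+# Illustrative usage sketch (a minimal example, not part of the upstream file; the public
+# `set_image` / `predict` entry points are assumed from the docstrings and error messages
+# above, and the variable names are placeholders):
+#
+#     predictor.set_image(image)                      # image: HxWx3 uint8 array
+#     masks, scores, low_res_logits = predictor.predict(
+#         point_coords=np.array([[320, 240]]),        # one foreground click, in pixels
+#         point_labels=np.array([1]),
+#         multimask_output=True,
+#     )
+#     best_mask = masks[scores.argmax()]              # pick the highest-quality candidate
+#     predictor.reset_predictor()                     # clear cached image features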
diff --git a/libs/sam2/sam2_video_predictor.py b/libs/sam2/sam2_video_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7e01ccf972491904b013526333826b337354db1
--- /dev/null
+++ b/libs/sam2/sam2_video_predictor.py
@@ -0,0 +1,1172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+from collections import OrderedDict
+
+import torch
+
+from tqdm import tqdm
+
+from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
+from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+
+
+class SAM2VideoPredictor(SAM2Base):
+ """The predictor class to handle user interactions and manage inference states."""
+
+ def __init__(
+ self,
+ fill_hole_area=0,
+ # whether to apply non-overlapping constraints on the output object masks
+ non_overlap_masks=False,
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+        # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
+ clear_non_cond_mem_around_input=False,
+ # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
+ clear_non_cond_mem_for_multi_obj=False,
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
+ add_all_frames_to_correct_as_cond=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.fill_hole_area = fill_hole_area
+ self.non_overlap_masks = non_overlap_masks
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+
+ @torch.inference_mode()
+ def init_state(
+ self,
+ video_path,
+ offload_video_to_cpu=False,
+ offload_state_to_cpu=False,
+ async_loading_frames=False,
+ ):
+ """Initialize an inference state."""
+ compute_device = self.device # device of the model
+ images, video_height, video_width = load_video_frames(
+ video_path=video_path,
+ image_size=self.image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ async_loading_frames=async_loading_frames,
+ compute_device=compute_device,
+ )
+ inference_state = {}
+ inference_state["images"] = images
+ inference_state["num_frames"] = len(images)
+ # whether to offload the video frames to CPU memory
+ # turning on this option saves the GPU memory with only a very small overhead
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+ # whether to offload the inference state to CPU memory
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+ # and from 24 to 21 when tracking two objects)
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+ # the original video height and width, used for resizing final output scores
+ inference_state["video_height"] = video_height
+ inference_state["video_width"] = video_width
+ inference_state["device"] = compute_device
+ if offload_state_to_cpu:
+ inference_state["storage_device"] = torch.device("cpu")
+ else:
+ inference_state["storage_device"] = compute_device
+ # inputs on each frame
+ inference_state["point_inputs_per_obj"] = {}
+ inference_state["mask_inputs_per_obj"] = {}
+ # visual features on a small number of recently visited frames for quick interactions
+ inference_state["cached_features"] = {}
+ # values that don't change across frames (so we only need to hold one copy of them)
+ inference_state["constants"] = {}
+ # mapping between client-side object id and model-side object index
+ inference_state["obj_id_to_idx"] = OrderedDict()
+ inference_state["obj_idx_to_id"] = OrderedDict()
+ inference_state["obj_ids"] = []
+ # A storage to hold the model's tracking results and states on each frame
+ inference_state["output_dict"] = {
+            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+ }
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+ inference_state["output_dict_per_obj"] = {}
+        # A temporary storage to hold new outputs when the user interacts with a frame
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+ inference_state["temp_output_dict_per_obj"] = {}
+        # Frames that already hold consolidated outputs from click or mask inputs
+ # (we directly use their consolidated outputs during tracking)
+ inference_state["consolidated_frame_inds"] = {
+ "cond_frame_outputs": set(), # set containing frame indices
+ "non_cond_frame_outputs": set(), # set containing frame indices
+ }
+ # metadata for each tracking frame (e.g. which direction it's tracked)
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"] = {}
+ # Warm up the visual backbone and cache the image feature on frame 0
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+ return inference_state
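+
+    # Illustrative sketch (not executed here; the video path is a placeholder): the CPU
+    # offload flags above trade GPU memory for tracking speed, so a long video might be
+    # initialized roughly like this:
+    #
+    #     state = predictor.init_state(
+    #         video_path="./videos/example",        # frames dir or video file, per `load_video_frames`
+    #         offload_video_to_cpu=True,            # keep decoded frames in CPU RAM
+    #         offload_state_to_cpu=True,            # keep per-frame outputs in CPU RAM (lower fps)
+    #     )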
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+ """
+ Load a pretrained model from the Hugging Face hub.
+
+ Arguments:
+ model_id (str): The Hugging Face repository ID.
+ **kwargs: Additional arguments to pass to the model constructor.
+
+ Returns:
+ (SAM2VideoPredictor): The loaded model.
+ """
+ from sam2.build_sam import build_sam2_video_predictor_hf
+
+ sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+ return sam_model
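+
+    # Illustrative sketch (the checkpoint id below is a placeholder and may not match an
+    # actual Hugging Face repository):
+    #
+    #     predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")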
+
+ def _obj_id_to_idx(self, inference_state, obj_id):
+ """Map client-side object id to model-side object index."""
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+        # This is a new object id that this session hasn't seen before. We only allow adding
+ # new objects *before* the tracking starts.
+ allow_new_object = not inference_state["tracking_has_started"]
+ if allow_new_object:
+ # get the next object slot
+ obj_idx = len(inference_state["obj_id_to_idx"])
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
+ # set up input and output structures for this object
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
+ inference_state["output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+ }
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+ }
+ return obj_idx
+ else:
+ raise RuntimeError(
+ f"Cannot add new object id {obj_id} after tracking starts. "
+ f"All existing object ids: {inference_state['obj_ids']}. "
+ f"Please call 'reset_state' to restart from scratch."
+ )
+
+ def _obj_idx_to_id(self, inference_state, obj_idx):
+ """Map model-side object index to client-side object id."""
+ return inference_state["obj_idx_to_id"][obj_idx]
+
+ def _get_obj_num(self, inference_state):
+ """Get the total number of unique object ids received so far in this session."""
+ return len(inference_state["obj_idx_to_id"])
+
+ @torch.inference_mode()
+ def add_new_points_or_box(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ points=None,
+ labels=None,
+ clear_old_points=True,
+ normalize_coords=True,
+ box=None,
+ ):
+ """Add new points to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if (points is not None) != (labels is not None):
+ raise ValueError("points and labels must be provided together")
+ if points is None and box is None:
+ raise ValueError("at least one of points or box must be provided as input")
+
+ if points is None:
+ points = torch.zeros(0, 2, dtype=torch.float32)
+ elif not isinstance(points, torch.Tensor):
+ points = torch.tensor(points, dtype=torch.float32)
+ if labels is None:
+ labels = torch.zeros(0, dtype=torch.int32)
+ elif not isinstance(labels, torch.Tensor):
+ labels = torch.tensor(labels, dtype=torch.int32)
+ if points.dim() == 2:
+ points = points.unsqueeze(0) # add batch dimension
+ if labels.dim() == 1:
+ labels = labels.unsqueeze(0) # add batch dimension
+
+ # If `box` is provided, we add it as the first two points with labels 2 and 3
+ # along with the user-provided points (consistent with how SAM 2 is trained).
+ if box is not None:
+ if not clear_old_points:
+ raise ValueError(
+ "cannot add box without clearing old points, since "
+ "box prompt must be provided before any point prompt "
+ "(please use clear_old_points=True instead)"
+ )
+ if inference_state["tracking_has_started"]:
+ warnings.warn(
+ "You are adding a box after tracking starts. SAM 2 may not always be "
+ "able to incorporate a box prompt for *refinement*. If you intend to "
+ "use box prompt as an *initial* input before tracking, please call "
+ "'reset_state' on the inference state to restart from scratch.",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ if not isinstance(box, torch.Tensor):
+ box = torch.tensor(box, dtype=torch.float32, device=points.device)
+ box_coords = box.reshape(1, 2, 2)
+ box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+ box_labels = box_labels.reshape(1, 2)
+ points = torch.cat([box_coords, points], dim=1)
+ labels = torch.cat([box_labels, labels], dim=1)
+
+ if normalize_coords:
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ points = points / torch.tensor([video_W, video_H]).to(points.device)
+ # scale the (normalized) coordinates by the model's internal image size
+ points = points * self.image_size
+ points = points.to(inference_state["device"])
+ labels = labels.to(inference_state["device"])
+
+ if not clear_old_points:
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
+ else:
+ point_inputs = None
+ point_inputs = concat_points(point_inputs, points, labels)
+
+ point_inputs_per_frame[frame_idx] = point_inputs
+ mask_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input points are used to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ # Get any previously predicted mask logits on this object and feed it along with
+ # the new clicks into the SAM mask decoder.
+ prev_sam_mask_logits = None
+ # lookup temporary output dict first, which contains the most recent output
+ # (if not found, then lookup conditioning and non-conditioning frame output)
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+
+ if prev_out is not None and prev_out["pred_masks"] is not None:
+ device = inference_state["device"]
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=None,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
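+
+    # Illustrative sketch (object id, frame index, and pixel coordinates are placeholders):
+    #
+    #     frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
+    #         inference_state=state,
+    #         frame_idx=0,
+    #         obj_id=1,
+    #         points=[[210, 350]],      # one (x, y) click in original-video pixels
+    #         labels=[1],               # 1 = foreground, 0 = background
+    #     )
+    #
+    # A box prompt can be passed instead via `box=[x0, y0, x1, y1]` (with clear_old_points=True).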
+
+ def add_new_points(self, *args, **kwargs):
+ """Deprecated method. Please use `add_new_points_or_box` instead."""
+ return self.add_new_points_or_box(*args, **kwargs)
+
+ @torch.inference_mode()
+ def add_new_mask(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ mask,
+ ):
+ """Add new mask to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if not isinstance(mask, torch.Tensor):
+ mask = torch.tensor(mask, dtype=torch.bool)
+ assert mask.dim() == 2
+ mask_H, mask_W = mask.shape
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
+
+ # resize the mask if it doesn't match the model's image size
+ if mask_H != self.image_size or mask_W != self.image_size:
+ mask_inputs = torch.nn.functional.interpolate(
+ mask_inputs_orig,
+ size=(self.image_size, self.image_size),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ mask_inputs = (mask_inputs >= 0.5).float()
+ else:
+ mask_inputs = mask_inputs_orig
+
+ mask_inputs_per_frame[frame_idx] = mask_inputs
+ point_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input mask is used to generate segments on this frame without
+        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+        # the input mask will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=None,
+ mask_inputs=mask_inputs,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
+ """
+ Resize the object scores to the original video resolution (video_res_masks)
+ and apply non-overlapping constraints for final output.
+ """
+ device = inference_state["device"]
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
+ if any_res_masks.shape[-2:] == (video_H, video_W):
+ video_res_masks = any_res_masks
+ else:
+ video_res_masks = torch.nn.functional.interpolate(
+ any_res_masks,
+ size=(video_H, video_W),
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.non_overlap_masks:
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
+ return any_res_masks, video_res_masks
+
+ def _consolidate_temp_output_across_obj(
+ self,
+ inference_state,
+ frame_idx,
+ is_cond,
+ run_mem_encoder,
+ consolidate_at_video_res=False,
+ ):
+ """
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
+ a frame into a single output for all objects, including
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
+ (if they don't exist in `output_dict_per_obj` for this frame);
+        2) if specified, rerun the memory encoder after applying non-overlapping constraints
+ on the object scores.
+ """
+ batch_size = self._get_obj_num(inference_state)
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Optionally, we allow consolidating the temporary outputs at the original
+ # video resolution (to provide a better editing experience for mask prompts).
+ if consolidate_at_video_res:
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
+ consolidated_H = inference_state["video_height"]
+ consolidated_W = inference_state["video_width"]
+ consolidated_mask_key = "pred_masks_video_res"
+ else:
+ consolidated_H = consolidated_W = self.image_size // 4
+ consolidated_mask_key = "pred_masks"
+
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+ # will be added when rerunning the memory encoder after applying non-overlapping
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
+ consolidated_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ consolidated_mask_key: torch.full(
+ size=(batch_size, 1, consolidated_H, consolidated_W),
+ fill_value=NO_OBJ_SCORE,
+ dtype=torch.float32,
+ device=inference_state["storage_device"],
+ ),
+ "obj_ptr": torch.full(
+ size=(batch_size, self.hidden_dim),
+ fill_value=NO_OBJ_SCORE,
+ dtype=torch.float32,
+ device=inference_state["device"],
+ ),
+ "object_score_logits": torch.full(
+ size=(batch_size, 1),
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+ fill_value=10.0,
+ dtype=torch.float32,
+ device=inference_state["device"],
+ ),
+ }
+ empty_mask_ptr = None
+ for obj_idx in range(batch_size):
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+ # we fall back and look up its previous output in "output_dict_per_obj".
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+ # "output_dict_per_obj" to find a previous output for this object.
+ if out is None:
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
+ if out is None:
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+ # placeholder above) and set its object pointer to be a dummy pointer.
+ if out is None:
+ # Fill in dummy object pointers for those objects without any inputs or
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+ # i.e. when we need to build the memory for tracking).
+ if run_mem_encoder:
+ if empty_mask_ptr is None:
+ empty_mask_ptr = self._get_empty_mask_ptr(
+ inference_state, frame_idx
+ )
+ # fill object pointer with a dummy pointer (based on an empty mask)
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
+ continue
+ # Add the temporary object output mask to consolidated output mask
+ obj_mask = out["pred_masks"]
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
+ else:
+ # Resize first if temporary object mask has a different resolution
+ resized_obj_mask = torch.nn.functional.interpolate(
+ obj_mask,
+ size=consolidated_pred_masks.shape[-2:],
+ mode="bilinear",
+ align_corners=False,
+ )
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
+ "object_score_logits"
+ ]
+
+ # Optionally, apply non-overlapping constraints on the consolidated scores
+ # and rerun the memory encoder
+ if run_mem_encoder:
+ device = inference_state["device"]
+ high_res_masks = torch.nn.functional.interpolate(
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.non_overlap_masks_for_mem_enc:
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
+ inference_state=inference_state,
+ frame_idx=frame_idx,
+ batch_size=batch_size,
+ high_res_masks=high_res_masks,
+ object_score_logits=consolidated_out["object_score_logits"],
+ is_mask_from_pts=True, # these frames are what the user interacted with
+ )
+ consolidated_out["maskmem_features"] = maskmem_features
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+
+ return consolidated_out
+
+ def _get_empty_mask_ptr(self, inference_state, frame_idx):
+ """Get a dummy object pointer based on an empty mask on the current frame."""
+ # A dummy (empty) mask with a single object
+ batch_size = 1
+ mask_inputs = torch.zeros(
+ (batch_size, 1, self.image_size, self.image_size),
+ dtype=torch.float32,
+ device=inference_state["device"],
+ )
+
+ # Retrieve correct image features
+ (
+ _,
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+ # Feed the empty mask and image feature above to get a dummy object pointer
+ current_out = self.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=True,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=None,
+ mask_inputs=mask_inputs,
+ output_dict={},
+ num_frames=inference_state["num_frames"],
+ track_in_reverse=False,
+ run_mem_encoder=False,
+ prev_sam_mask_logits=None,
+ )
+ return current_out["obj_ptr"]
+
+ @torch.inference_mode()
+ def propagate_in_video_preflight(self, inference_state):
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
+ # Tracking has started and we don't allow adding new objects until session is reset.
+ inference_state["tracking_has_started"] = True
+ batch_size = self._get_obj_num(inference_state)
+
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+ # add them into "output_dict".
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ output_dict = inference_state["output_dict"]
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
+ # temporary outputs have been added (either in this call or any previous calls
+ # to `propagate_in_video_preflight`).
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ for is_cond in [False, True]:
+ # Separately consolidate conditioning and non-conditioning temp outputs
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Find all the frames that contain temporary outputs for any objects
+ # (these should be the frames that have just received clicks for mask inputs
+ # via `add_new_points_or_box` or `add_new_mask`)
+ temp_frame_inds = set()
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
+ # consolidate the temporary output across all objects on this frame
+ for frame_idx in temp_frame_inds:
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
+ )
+ # merge them into "output_dict" and also create per-object slices
+ output_dict[storage_key][frame_idx] = consolidated_out
+ self._add_output_per_object(
+ inference_state, frame_idx, consolidated_out, storage_key
+ )
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+ )
+ if clear_non_cond_mem:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+
+ # clear temporary outputs in `temp_output_dict_per_obj`
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ obj_temp_output_dict[storage_key].clear()
+
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+ # output on the same frame in "non_cond_frame_outputs"
+ for frame_idx in output_dict["cond_frame_outputs"]:
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ assert frame_idx in output_dict["cond_frame_outputs"]
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+ # with either points or mask inputs (which should be true under a correct workflow).
+ all_consolidated_frame_inds = (
+ consolidated_frame_inds["cond_frame_outputs"]
+ | consolidated_frame_inds["non_cond_frame_outputs"]
+ )
+ input_frames_inds = set()
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
+ input_frames_inds.update(point_inputs_per_frame.keys())
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
+ input_frames_inds.update(mask_inputs_per_frame.keys())
+ assert all_consolidated_frame_inds == input_frames_inds
+
+ @torch.inference_mode()
+ def propagate_in_video(
+ self,
+ inference_state,
+ start_frame_idx=None,
+ max_frame_num_to_track=None,
+ reverse=False,
+ ):
+ """Propagate the input points across frames to track in the entire video."""
+ self.propagate_in_video_preflight(inference_state)
+
+ output_dict = inference_state["output_dict"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ obj_ids = inference_state["obj_ids"]
+ num_frames = inference_state["num_frames"]
+ batch_size = self._get_obj_num(inference_state)
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ raise RuntimeError("No points are provided; please add points first")
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+ )
+
+ # set start index, end index, and processing order
+ if start_frame_idx is None:
+ # default: start from the earliest frame with input points
+ start_frame_idx = min(output_dict["cond_frame_outputs"])
+ if max_frame_num_to_track is None:
+ # default: track all the frames in the video
+ max_frame_num_to_track = num_frames
+ if reverse:
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+ if start_frame_idx > 0:
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+ else:
+ processing_order = [] # skip reverse tracking if starting from frame 0
+ else:
+ end_frame_idx = min(
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
+ )
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
+ # We skip those frames already in consolidated outputs (these are frames
+ # that received input clicks or mask). Note that we cannot directly run
+ # batched forward on them via `_run_single_frame_inference` because the
+ # number of clicks on each object might be different.
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ storage_key = "cond_frame_outputs"
+ current_out = output_dict[storage_key][frame_idx]
+ pred_masks = current_out["pred_masks"]
+ if clear_non_cond_mem:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
+ storage_key = "non_cond_frame_outputs"
+ current_out = output_dict[storage_key][frame_idx]
+ pred_masks = current_out["pred_masks"]
+ else:
+ storage_key = "non_cond_frame_outputs"
+ current_out, pred_masks = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=output_dict,
+ frame_idx=frame_idx,
+ batch_size=batch_size,
+ is_init_cond_frame=False,
+ point_inputs=None,
+ mask_inputs=None,
+ reverse=reverse,
+ run_mem_encoder=True,
+ )
+ output_dict[storage_key][frame_idx] = current_out
+ # Create slices of per-object outputs for subsequent interaction with each
+ # individual object after tracking.
+ self._add_output_per_object(
+ inference_state, frame_idx, current_out, storage_key
+ )
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
+
+ # Resize the output mask to the original video resolution (we directly use
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, pred_masks
+ )
+ yield frame_idx, obj_ids, video_res_masks
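+
+    # Illustrative sketch: since this method is a generator, results are typically collected
+    # frame by frame (variable names below are placeholders):
+    #
+    #     video_segments = {}
+    #     for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
+    #         video_segments[frame_idx] = {
+    #             obj_id: (masks[i] > 0.0).cpu().numpy()  # threshold logits to binary masks
+    #             for i, obj_id in enumerate(obj_ids)
+    #         }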
+
+ def _add_output_per_object(
+ self, inference_state, frame_idx, current_out, storage_key
+ ):
+ """
+ Split a multi-object output into per-object output slices and add them into
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
+ """
+ maskmem_features = current_out["maskmem_features"]
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
+ obj_slice = slice(obj_idx, obj_idx + 1)
+ obj_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": current_out["pred_masks"][obj_slice],
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
+ "object_score_logits": current_out["object_score_logits"][obj_slice],
+ }
+ if maskmem_features is not None:
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
+ if maskmem_pos_enc is not None:
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+ obj_output_dict[storage_key][frame_idx] = obj_out
+
+ @torch.inference_mode()
+ def clear_all_prompts_in_frame(
+ self, inference_state, frame_idx, obj_id, need_output=True
+ ):
+ """Remove all input points or mask in a specific frame for a given object."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+
+ # Clear the conditioning information on the given frame
+ inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+ inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
+
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
+ temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
+
+ # Check and see if there are still any inputs left on this frame
+ batch_size = self._get_obj_num(inference_state)
+ frame_has_input = False
+ for obj_idx2 in range(batch_size):
+ if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
+ frame_has_input = True
+ break
+ if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
+ frame_has_input = True
+ break
+
+ # If this frame has no remaining inputs for any objects, we further clear its
+ # conditioning frame status
+ if not frame_has_input:
+ output_dict = inference_state["output_dict"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+ # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
+ out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
+ if out is not None:
+ # The frame is not a conditioning frame anymore since it's not receiving inputs,
+ # so we "downgrade" its output (if exists) to a non-conditioning frame output.
+ output_dict["non_cond_frame_outputs"][frame_idx] = out
+ inference_state["frames_already_tracked"].pop(frame_idx, None)
+ # Similarly, do it for the sliced output on each object.
+ for obj_idx2 in range(batch_size):
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
+ obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
+ if obj_out is not None:
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
+
+ # If all the conditioning frames have been removed, we also clear the tracking outputs
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ self._reset_tracking_results(inference_state)
+
+ if not need_output:
+ return
+ # Finally, output updated masks per object (after removing the inputs above)
+ obj_ids = inference_state["obj_ids"]
+ is_cond = any(
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
+ )
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ @torch.inference_mode()
+ def reset_state(self, inference_state):
+ """Remove all input points or mask in all frames throughout the video."""
+ self._reset_tracking_results(inference_state)
+ # Remove all object ids
+ inference_state["obj_id_to_idx"].clear()
+ inference_state["obj_idx_to_id"].clear()
+ inference_state["obj_ids"].clear()
+ inference_state["point_inputs_per_obj"].clear()
+ inference_state["mask_inputs_per_obj"].clear()
+ inference_state["output_dict_per_obj"].clear()
+ inference_state["temp_output_dict_per_obj"].clear()
+
+ def _reset_tracking_results(self, inference_state):
+        """Reset all tracking inputs and results across the video."""
+ for v in inference_state["point_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["mask_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ for v in inference_state["temp_output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"].clear()
+
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
+ """Compute the image features on a given frame."""
+ # Look up in the cache first
+ image, backbone_out = inference_state["cached_features"].get(
+ frame_idx, (None, None)
+ )
+ if backbone_out is None:
+ # Cache miss -- we will run inference on a single image
+ device = inference_state["device"]
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+ backbone_out = self.forward_image(image)
+ # Cache the most recent frame's feature (for repeated interactions with
+ # a frame; we can use an LRU cache for more frames in the future).
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+
+ # expand the features to have the same dimension as the number of objects
+ expanded_image = image.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out = {
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+ }
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
+ batch_size, -1, -1, -1
+ )
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
+ pos = pos.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out["vision_pos_enc"][i] = pos
+
+ features = self._prepare_backbone_features(expanded_backbone_out)
+ features = (expanded_image,) + features
+ return features
+
+ def _run_single_frame_inference(
+ self,
+ inference_state,
+ output_dict,
+ frame_idx,
+ batch_size,
+ is_init_cond_frame,
+ point_inputs,
+ mask_inputs,
+ reverse,
+ run_mem_encoder,
+ prev_sam_mask_logits=None,
+ ):
+ """Run tracking on a single frame based on current inputs and previous memory."""
+ # Retrieve correct image features
+ (
+ _,
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+ # point and mask should not appear as input simultaneously on the same frame
+ assert point_inputs is None or mask_inputs is None
+ current_out = self.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ output_dict=output_dict,
+ num_frames=inference_state["num_frames"],
+ track_in_reverse=reverse,
+ run_mem_encoder=run_mem_encoder,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = current_out["maskmem_features"]
+ if maskmem_features is not None:
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ pred_masks_gpu = current_out["pred_masks"]
+ # potentially fill holes in the predicted masks
+ if self.fill_hole_area > 0:
+ pred_masks_gpu = fill_holes_in_mask_scores(
+ pred_masks_gpu, self.fill_hole_area
+ )
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+ obj_ptr = current_out["obj_ptr"]
+ object_score_logits = current_out["object_score_logits"]
+ # make a compact version of this frame's output to reduce the state size
+ compact_current_out = {
+ "maskmem_features": maskmem_features,
+ "maskmem_pos_enc": maskmem_pos_enc,
+ "pred_masks": pred_masks,
+ "obj_ptr": obj_ptr,
+ "object_score_logits": object_score_logits,
+ }
+ return compact_current_out, pred_masks_gpu
+
+ def _run_memory_encoder(
+ self,
+ inference_state,
+ frame_idx,
+ batch_size,
+ high_res_masks,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """
+ Run the memory encoder on `high_res_masks`. This is usually after applying
+ non-overlapping constraints to object scores. Since their scores changed, their
+        memory also needs to be computed again with the memory encoder.
+ """
+ # Retrieve correct image features
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
+ inference_state, frame_idx, batch_size
+ )
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks,
+ object_score_logits=object_score_logits,
+ is_mask_from_pts=is_mask_from_pts,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
+ )
+ return maskmem_features, maskmem_pos_enc
+
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
+ """
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
+ a constant in the inference session to reduce session storage size.
+ """
+ model_constants = inference_state["constants"]
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ if out_maskmem_pos_enc is not None:
+ if "maskmem_pos_enc" not in model_constants:
+ assert isinstance(out_maskmem_pos_enc, list)
+ # only take the slice for one object, since it's same across objects
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+ # expand the cached maskmem_pos_enc to the actual batch size
+ batch_size = out_maskmem_pos_enc[0].size(0)
+ expanded_maskmem_pos_enc = [
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
+ ]
+ else:
+ expanded_maskmem_pos_enc = None
+ return expanded_maskmem_pos_enc
+
+ @torch.inference_mode()
+ def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
+ """
+ Remove an object id from the tracking state. If strict is True, we check whether
+ the object id actually exists and raise an error if it doesn't exist.
+ """
+ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
+ updated_frames = []
+ # Check whether this object_id to remove actually exists and possibly raise an error.
+ if old_obj_idx_to_rm is None:
+ if not strict:
+ return inference_state["obj_ids"], updated_frames
+ raise RuntimeError(
+ f"Cannot remove object id {obj_id} as it doesn't exist. "
+ f"All existing object ids: {inference_state['obj_ids']}."
+ )
+
+ # If this is the only remaining object id, we simply reset the state.
+ if len(inference_state["obj_id_to_idx"]) == 1:
+ self.reset_state(inference_state)
+ return inference_state["obj_ids"], updated_frames
+
+ # There are still remaining objects after removing this object id. In this case,
+ # we need to delete the object storage from inference state tensors.
+ # Step 0: clear the input on those frames where this object id has point or mask input
+ # (note that this step is required as it might downgrade conditioning frames to
+ # non-conditioning ones)
+ obj_input_frames_inds = set()
+ obj_input_frames_inds.update(
+ inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
+ )
+ obj_input_frames_inds.update(
+ inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
+ )
+ for frame_idx in obj_input_frames_inds:
+ self.clear_all_prompts_in_frame(
+ inference_state, frame_idx, obj_id, need_output=False
+ )
+
+ # Step 1: Update the object id mapping (note that it must be done after Step 0,
+ # since Step 0 still requires the old object id mappings in inference_state)
+ old_obj_ids = inference_state["obj_ids"]
+ old_obj_inds = list(range(len(old_obj_ids)))
+ remain_old_obj_inds = old_obj_inds.copy()
+ remain_old_obj_inds.remove(old_obj_idx_to_rm)
+ new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
+ new_obj_inds = list(range(len(new_obj_ids)))
+ # build new mappings
+ old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
+ inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
+ inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
+ inference_state["obj_ids"] = new_obj_ids
+
+ # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
+ # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
+ # it's already handled in Step 0)
+ def _map_keys(container):
+ new_kvs = []
+ for k in old_obj_inds:
+ v = container.pop(k)
+ if k in old_idx_to_new_idx:
+ new_kvs.append((old_idx_to_new_idx[k], v))
+ container.update(new_kvs)
+
+ _map_keys(inference_state["point_inputs_per_obj"])
+ _map_keys(inference_state["mask_inputs_per_obj"])
+ _map_keys(inference_state["output_dict_per_obj"])
+ _map_keys(inference_state["temp_output_dict_per_obj"])
+
+ # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
+ def _slice_state(output_dict, storage_key):
+ for frame_idx, out in output_dict[storage_key].items():
+ out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
+ out["maskmem_pos_enc"] = [
+ x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
+ ]
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
+ out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
+ out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
+ out["object_score_logits"] = out["object_score_logits"][
+ remain_old_obj_inds
+ ]
+ # also update the per-object slices
+ self._add_output_per_object(
+ inference_state, frame_idx, out, storage_key
+ )
+
+ _slice_state(inference_state["output_dict"], "cond_frame_outputs")
+ _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
+
+ # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
+ # could show an updated mask for objects previously occluded by the object being removed
+ if need_output:
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ for frame_idx in obj_input_frames_inds:
+ is_cond = any(
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
+ )
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ updated_frames.append((frame_idx, video_res_masks))
+
+ return inference_state["obj_ids"], updated_frames
+
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
+ """
+ Remove the non-conditioning memory around the input frame. When users provide
+ correction clicks, the surrounding frames' non-conditioning memories can still
+ contain outdated object appearance information and could confuse the model.
+
+ This method clears those non-conditioning memories surrounding the interacted
+ frame to avoid giving the model both old and new information about the object.
+ """
+ r = self.memory_temporal_stride_for_eval
+ frame_idx_begin = frame_idx - r * self.num_maskmem
+ frame_idx_end = frame_idx + r * self.num_maskmem
+ output_dict = inference_state["output_dict"]
+ non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
+ for t in range(frame_idx_begin, frame_idx_end + 1):
+ non_cond_frame_outputs.pop(t, None)
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
diff --git a/libs/sam2/utils/__init__.py b/libs/sam2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/libs/sam2/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/libs/sam2/utils/amg.py b/libs/sam2/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e
--- /dev/null
+++ b/libs/sam2/utils/amg.py
@@ -0,0 +1,348 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+import numpy as np
+import torch
+
+# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.float().detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
+
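+# Example (illustrative sketch of the uncompressed RLE format): counts alternate
+# background/foreground run lengths over the column-major (Fortran-order) flattened mask.
+#   >>> m = torch.tensor([[[0, 1, 1], [0, 0, 1]]], dtype=torch.bool)  # 1x2x3 mask
+#   >>> rle = mask_to_rle_pytorch(m)[0]
+#   >>> rle["size"], rle["counts"]
+#   ([2, 3], [2, 1, 1, 2])
+#   >>> bool(np.array_equal(rle_to_mask(rle), m[0].numpy()))
+#   True
+#   >>> area_from_rle(rle)
+#   3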
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
+
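+# Example (illustrative sketch): logits thresholded at 0.0 +/- 1.0.
+#   >>> logits = torch.tensor([[[2.0, 0.5, -2.0, -2.0]]])
+#   >>> calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0)
+#   tensor([0.5000])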
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
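+# Example (illustrative sketch): a 2x2 grid lands at the quarter/three-quarter points.
+#   >>> build_point_grid(2)
+#   array([[0.25, 0.25],
+#          [0.75, 0.25],
+#          [0.25, 0.75],
+#          [0.75, 0.75]])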
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer
+ has (2**i)**2 boxes for the ith layer.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYXY format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
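+# Example (illustrative sketch): one crop layer on a 100x200 (HxW) image.
+#   >>> boxes, layers = generate_crop_boxes((100, 200), n_layers=1, overlap_ratio=0.2)
+#   >>> layers
+#   [0, 1, 1, 1, 1]
+#   >>> boxes[0]   # layer 0 is always the full image, in XYXY format
+#   [0, 0, 200, 100]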
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+ mask and an indicator of if the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+    Calculates boxes in XYXY format around masks. Returns [0,0,0,0] for
+    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
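+
+# Example (illustrative sketch): a single 4x5 mask with foreground in rows 1-2, cols 1-3.
+#   >>> m = torch.zeros(4, 5, dtype=torch.bool)
+#   >>> m[1:3, 1:4] = True
+#   >>> batched_mask_to_box(m)
+#   tensor([1, 1, 3, 2])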
diff --git a/libs/sam2/utils/misc.py b/libs/sam2/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65ee825732ff85137805be650edd4cbe8e6f6d4
--- /dev/null
+++ b/libs/sam2/utils/misc.py
@@ -0,0 +1,349 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import warnings
+from threading import Thread
+
+import numpy as np
+import torch
+from PIL import Image
+from tqdm import tqdm
+
+
+def get_sdpa_settings():
+ if torch.cuda.is_available():
+ old_gpu = torch.cuda.get_device_properties(0).major < 7
+ # only use Flash Attention on Ampere (8.0) or newer GPUs
+ use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
+ if not use_flash_attn:
+ warnings.warn(
+ "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
+ # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
+ pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
+ if pytorch_version < (2, 2):
+ warnings.warn(
+ f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
+ "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
+ else:
+ old_gpu = True
+ use_flash_attn = False
+ math_kernel_on = True
+
+ return old_gpu, use_flash_attn, math_kernel_on
+
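+# These flags are typically consumed via torch.backends.cuda.sdp_kernel when calling
+# scaled_dot_product_attention; a minimal sketch (q, k, v are assumed attention tensors):
+#   >>> OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+#   >>> with torch.backends.cuda.sdp_kernel(
+#   ...     enable_flash=USE_FLASH_ATTN,
+#   ...     enable_math=MATH_KERNEL_ON,
+#   ...     enable_mem_efficient=OLD_GPU,
+#   ... ):
+#   ...     out = torch.nn.functional.scaled_dot_product_attention(q, k, v)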
+
+def get_connected_components(mask):
+ """
+ Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
+
+ Inputs:
+ - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
+ background.
+
+ Outputs:
+ - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
+ for foreground pixels and 0 for background pixels.
+ - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
+ components for foreground pixels and 0 for background pixels.
+ """
+ from sam2 import _C
+
+ return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
+
+
+def mask_to_box(masks: torch.Tensor):
+ """
+ compute bounding box given an input mask
+
+ Inputs:
+ - masks: [B, 1, H, W] masks, dtype=torch.Tensor
+
+ Returns:
+ - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
+ """
+ B, _, h, w = masks.shape
+ device = masks.device
+ xs = torch.arange(w, device=device, dtype=torch.int32)
+ ys = torch.arange(h, device=device, dtype=torch.int32)
+ grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
+ grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
+ grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
+ min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
+ max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
+ min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
+ max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
+ bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
+
+ return bbox_coords
+
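+# Example (illustrative sketch): a 1x1x3x4 mask with foreground at rows 1-2, cols 2-3.
+#   >>> m = torch.zeros(1, 1, 3, 4, dtype=torch.bool)
+#   >>> m[0, 0, 1:3, 2:4] = True
+#   >>> mask_to_box(m)
+#   tensor([[[2, 1, 3, 2]]], dtype=torch.int32)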
+
+def _load_img_as_tensor(img_path, image_size):
+ img_pil = Image.open(img_path)
+ img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
+ if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
+ img_np = img_np / 255.0
+ else:
+ raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
+ img = torch.from_numpy(img_np).permute(2, 0, 1)
+ video_width, video_height = img_pil.size # the original video size
+ return img, video_height, video_width
+
+
+class AsyncVideoFrameLoader:
+ """
+    A list of video frames to be loaded asynchronously without blocking session start.
+ """
+
+ def __init__(
+ self,
+ img_paths,
+ image_size,
+ offload_video_to_cpu,
+ img_mean,
+ img_std,
+ compute_device,
+ ):
+ self.img_paths = img_paths
+ self.image_size = image_size
+ self.offload_video_to_cpu = offload_video_to_cpu
+ self.img_mean = img_mean
+ self.img_std = img_std
+ # items in `self.images` will be loaded asynchronously
+ self.images = [None] * len(img_paths)
+ # catch and raise any exceptions in the async loading thread
+ self.exception = None
+        # video_height and video_width will be filled when loading the first image
+ self.video_height = None
+ self.video_width = None
+ self.compute_device = compute_device
+
+ # load the first frame to fill video_height and video_width and also
+ # to cache it (since it's most likely where the user will click)
+ self.__getitem__(0)
+
+ # load the rest of frames asynchronously without blocking the session start
+ def _load_frames():
+ try:
+ for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
+ self.__getitem__(n)
+ except Exception as e:
+ self.exception = e
+
+ self.thread = Thread(target=_load_frames, daemon=True)
+ self.thread.start()
+
+ def __getitem__(self, index):
+ if self.exception is not None:
+ raise RuntimeError("Failure in frame loading thread") from self.exception
+
+ img = self.images[index]
+ if img is not None:
+ return img
+
+ img, video_height, video_width = _load_img_as_tensor(
+ self.img_paths[index], self.image_size
+ )
+ self.video_height = video_height
+ self.video_width = video_width
+ # normalize by mean and std
+ img -= self.img_mean
+ img /= self.img_std
+ if not self.offload_video_to_cpu:
+ img = img.to(self.compute_device, non_blocking=True)
+ self.images[index] = img
+ return img
+
+ def __len__(self):
+ return len(self.images)
+
+
+def load_video_frames(
+ video_path,
+ image_size,
+ offload_video_to_cpu,
+ img_mean=(0.485, 0.456, 0.406),
+ img_std=(0.229, 0.224, 0.225),
+ async_loading_frames=False,
+ compute_device=torch.device("cuda"),
+):
+ """
+ Load the video frames from video_path. The frames are resized to image_size as in
+ the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
+ """
+ is_bytes = isinstance(video_path, bytes)
+ is_str = isinstance(video_path, str)
+ is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
+ if is_bytes or is_mp4_path:
+ return load_video_frames_from_video_file(
+ video_path=video_path,
+ image_size=image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ img_mean=img_mean,
+ img_std=img_std,
+ compute_device=compute_device,
+ )
+ elif is_str and os.path.isdir(video_path):
+ return load_video_frames_from_jpg_images(
+ video_path=video_path,
+ image_size=image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ img_mean=img_mean,
+ img_std=img_std,
+ async_loading_frames=async_loading_frames,
+ compute_device=compute_device,
+ )
+ else:
+ raise NotImplementedError(
+ "Only MP4 video and JPEG folder are supported at this moment"
+ )
+
+
+def load_video_frames_from_jpg_images(
+ video_path,
+ image_size,
+ offload_video_to_cpu,
+ img_mean=(0.485, 0.456, 0.406),
+ img_std=(0.229, 0.224, 0.225),
+ async_loading_frames=False,
+ compute_device=torch.device("cuda"),
+):
+ """
+ Load the video frames from a directory of JPEG files (".jpg" format).
+
+ The frames are resized to image_size x image_size and are loaded to GPU if
+ `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
+
+    You can load frames asynchronously by setting `async_loading_frames` to `True`.
+ """
+ if isinstance(video_path, str) and os.path.isdir(video_path):
+ jpg_folder = video_path
+ else:
+ raise NotImplementedError(
+ "Only JPEG frames are supported at this moment. For video files, you may use "
+ "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
+ "```\n"
+            "ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
+ "```\n"
+ "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
+ "ffmpeg to start the JPEG file from 00000.jpg."
+ )
+
+ frame_names = [
+ p
+ for p in os.listdir(jpg_folder)
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+ ]
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+ num_frames = len(frame_names)
+ if num_frames == 0:
+ raise RuntimeError(f"no images found in {jpg_folder}")
+ img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+
+ if async_loading_frames:
+ lazy_images = AsyncVideoFrameLoader(
+ img_paths,
+ image_size,
+ offload_video_to_cpu,
+ img_mean,
+ img_std,
+ compute_device,
+ )
+ return lazy_images, lazy_images.video_height, lazy_images.video_width
+
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
+ for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
+ images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+ if not offload_video_to_cpu:
+ images = images.to(compute_device)
+ img_mean = img_mean.to(compute_device)
+ img_std = img_std.to(compute_device)
+ # normalize by mean and std
+ images -= img_mean
+ images /= img_std
+ return images, video_height, video_width
+
+
+def load_video_frames_from_video_file(
+ video_path,
+ image_size,
+ offload_video_to_cpu,
+ img_mean=(0.485, 0.456, 0.406),
+ img_std=(0.229, 0.224, 0.225),
+ compute_device=torch.device("cuda"),
+):
+ """Load the video frames from a video file."""
+ import decord
+
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+ # Get the original video height and width
+ decord.bridge.set_bridge("torch")
+ video_height, video_width, _ = decord.VideoReader(video_path).next().shape
+ # Iterate over all frames in the video
+ images = []
+ for frame in decord.VideoReader(video_path, width=image_size, height=image_size):
+ images.append(frame.permute(2, 0, 1))
+
+ images = torch.stack(images, dim=0).float() / 255.0
+ if not offload_video_to_cpu:
+ images = images.to(compute_device)
+ img_mean = img_mean.to(compute_device)
+ img_std = img_std.to(compute_device)
+ # normalize by mean and std
+ images -= img_mean
+ images /= img_std
+ return images, video_height, video_width
+
+
+def fill_holes_in_mask_scores(mask, max_area):
+ """
+ A post processor to fill small holes in mask scores with area under `max_area`.
+ """
+ # Holes are those connected components in background with area <= self.max_area
+ # (background regions are those with mask scores <= 0)
+ assert max_area > 0, "max_area must be positive"
+
+ input_mask = mask
+ try:
+ labels, areas = get_connected_components(mask <= 0)
+ is_hole = (labels > 0) & (areas <= max_area)
+ # We fill holes with a small positive mask score (0.1) to change them to foreground.
+ mask = torch.where(is_hole, 0.1, mask)
+ except Exception as e:
+ # Skip the post-processing step on removing small holes if the CUDA kernel fails
+ warnings.warn(
+ f"{e}\n\nSkipping the post-processing step due to the error above. You can "
+ "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
+ "functionality may be limited (which doesn't affect the results in most cases; see "
+ "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ mask = input_mask
+
+ return mask
+
+
+def concat_points(old_point_inputs, new_points, new_labels):
+ """Add new points and labels to previous point inputs (add at the end)."""
+ if old_point_inputs is None:
+ points, labels = new_points, new_labels
+ else:
+ points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
+ labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
+
+ return {"point_coords": points, "point_labels": labels}
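+
+# Example (illustrative sketch): append one new click to an existing point prompt.
+#   >>> old = {"point_coords": torch.zeros(1, 1, 2), "point_labels": torch.ones(1, 1)}
+#   >>> new = concat_points(old, torch.ones(1, 1, 2), torch.zeros(1, 1))
+#   >>> new["point_coords"].shape, new["point_labels"].shape
+#   (torch.Size([1, 2, 2]), torch.Size([1, 2]))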
diff --git a/libs/sam2/utils/transforms.py b/libs/sam2/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc17bebfab104b659c5469e8434cf357ae7e24b6
--- /dev/null
+++ b/libs/sam2/utils/transforms.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Normalize, Resize, ToTensor
+
+
+class SAM2Transforms(nn.Module):
+ def __init__(
+ self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
+ ):
+ """
+ Transforms for SAM2.
+ """
+ super().__init__()
+ self.resolution = resolution
+ self.mask_threshold = mask_threshold
+ self.max_hole_area = max_hole_area
+ self.max_sprinkle_area = max_sprinkle_area
+ self.mean = [0.485, 0.456, 0.406]
+ self.std = [0.229, 0.224, 0.225]
+ self.to_tensor = ToTensor()
+ self.transforms = torch.jit.script(
+ nn.Sequential(
+ Resize((self.resolution, self.resolution)),
+ Normalize(self.mean, self.std),
+ )
+ )
+
+ def __call__(self, x):
+ x = self.to_tensor(x)
+ return self.transforms(x)
+
+ def forward_batch(self, img_list):
+ img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
+ img_batch = torch.stack(img_batch, dim=0)
+ return img_batch
+
+ def transform_coords(
+ self, coords: torch.Tensor, normalize=False, orig_hw=None
+ ) -> torch.Tensor:
+ """
+        Expects a torch tensor with length 2 in the last dimension. The coordinates can be
+        either absolute image coordinates or normalized coordinates. If the coords are
+        absolute image coordinates, `normalize` should be set to True and the original
+        image size is required.
+
+        Returns
+            Coordinates scaled to the model's input resolution (i.e., in [0, self.resolution]),
+            which is what the SAM2 model expects.
+ """
+ if normalize:
+ assert orig_hw is not None
+ h, w = orig_hw
+ coords = coords.clone()
+ coords[..., 0] = coords[..., 0] / w
+ coords[..., 1] = coords[..., 1] / h
+
+ coords = coords * self.resolution # unnormalize coords
+ return coords
+
+ def transform_boxes(
+ self, boxes: torch.Tensor, normalize=False, orig_hw=None
+ ) -> torch.Tensor:
+ """
+        Expects a tensor of shape Bx4. The coordinates can be either absolute image
+        coordinates or normalized coordinates. If the coords are absolute image coordinates,
+        `normalize` should be set to True and the original image size is required.
+ """
+ boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
+ return boxes
+
+ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
+ """
+        Perform post-processing on output masks.
+ """
+ from sam2.utils.misc import get_connected_components
+
+ masks = masks.float()
+ input_masks = masks
+ mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
+ try:
+ if self.max_hole_area > 0:
+ # Holes are those connected components in background with area <= self.fill_hole_area
+ # (background regions are those with mask scores <= self.mask_threshold)
+ labels, areas = get_connected_components(
+ mask_flat <= self.mask_threshold
+ )
+ is_hole = (labels > 0) & (areas <= self.max_hole_area)
+ is_hole = is_hole.reshape_as(masks)
+ # We fill holes with a small positive mask score (10.0) to change them to foreground.
+ masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
+
+ if self.max_sprinkle_area > 0:
+ labels, areas = get_connected_components(
+ mask_flat > self.mask_threshold
+ )
+ is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
+ is_hole = is_hole.reshape_as(masks)
+ # We fill holes with negative mask score (-10.0) to change them to background.
+ masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
+ except Exception as e:
+ # Skip the post-processing step if the CUDA kernel fails
+ warnings.warn(
+ f"{e}\n\nSkipping the post-processing step due to the error above. You can "
+ "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
+ "functionality may be limited (which doesn't affect the results in most cases; see "
+ "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
+ category=UserWarning,
+ stacklevel=2,
+ )
+ masks = input_masks
+
+ masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
+ return masks
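+
+# Example (illustrative sketch): map absolute pixel coords on a 512x512 image to the
+# model's input resolution (assumes torchvision is available for the constructor).
+#   >>> t = SAM2Transforms(resolution=1024, mask_threshold=0.0)
+#   >>> pts = torch.tensor([[256.0, 128.0]])
+#   >>> t.transform_coords(pts, normalize=True, orig_hw=(512, 512))
+#   tensor([[512., 256.]])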
diff --git a/libs/sv3d/.gitignore b/libs/sv3d/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..accdfedefb65597f79c22443892f6cc7a627917c
--- /dev/null
+++ b/libs/sv3d/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+out/
+temp/
+temp*.*
diff --git a/libs/sv3d/README.md b/libs/sv3d/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f6fdc6517315f41d8f8c324728f22e858363e86d
--- /dev/null
+++ b/libs/sv3d/README.md
@@ -0,0 +1,53 @@
+# SV3D-diffusers
+
+
+
+This repo provides scripts for:
+
+1. Spatio-temporal UNet (`SV3DUNetSpatioTemporalConditionModel`) and pipeline (`StableVideo3DDiffusionPipeline`) modified from [SVD](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py) for [SV3D](https://sv3d.github.io) in the [diffusers](https://github.com/huggingface/diffusers) convention.
+
+2. Converting the [Stability-AI](https://github.com/Stability-AI/generative-models)'s [SV3D-p UNet checkpoint](https://huggingface.co/stabilityai/sv3d) to the [diffusers](https://github.com/huggingface/diffusers) convention.
+
+3. Running inference with the `SV3D-p` model via the [diffusers](https://github.com/huggingface/diffusers) library to synthesize a 21-frame orbital video around a 3D object from a single-view image (preprocessed by removing the background and centering the object first).
+
+Converted SV3D-p checkpoints have been uploaded to HuggingFace🤗 [chenguolin/sv3d-diffusers](https://huggingface.co/chenguolin/sv3d-diffusers).
+
+
+## 🔥 See Also
+
+You may also be interested in our works:
+- [**[ICLR 2025] DiffSplat**](https://github.com/chenguolin/DiffSplat): generates 3D objects in 3DGS directly by fine-tuning text-to-image models.
+- [**[NeurIPS 2024] HumanSplat**](https://github.com/humansplat/humansplat): SV3D is fine-tuned on human datasets for single-view human reconstruction.
+
+
+## 🚀 Usage
+```bash
+git clone https://github.com/chenguolin/sv3d-diffusers.git
+# Please install PyTorch first according to your CUDA version
+pip3 install -r requirements.txt
+# If you can't access HuggingFace🤗, try:
+# export HF_ENDPOINT=https://hf-mirror.com
+python3 infer.py --output_dir out/ --image_path assets/images/sculpture.png --elevation 10 --half_precision --seed -1
+```
+The synthesized video will be saved to `out/` as a `.gif` file.
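+
+To load the converted weights directly in Python, a minimal sketch (it assumes the custom `StableVideo3DDiffusionPipeline` class from this repo is importable and the converted checkpoint is available on the Hub):
+
+```python
+import torch
+from diffusers_sv3d import StableVideo3DDiffusionPipeline
+
+# Assumes the checkpoint layout produced by convert/convert_sv3d.py.
+pipe = StableVideo3DDiffusionPipeline.from_pretrained(
+    "chenguolin/sv3d-diffusers", torch_dtype=torch.float16
+).to("cuda")
+# See infer.py for image preprocessing and the full sampling call.
+```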
+
+
+## 📸 Results
+> Image preprocessing and random seeds differ across implementations, so the results are presented only for reference.
+
+| Implementation | sculpture | bag | kunkun |
+| :------------- | :------: | :----: | :----: |
+| **SV3D-diffusers (Ours)** |  |  |  |
+| **Official SV3D** |  |  |  |
+
+
+## 📚 Citation
+If you find this repo helpful, please consider giving it a star 🌟 and citing the original SV3D paper.
+```
+@inproceedings{voleti2024sv3d,
+ author={Voleti, Vikram and Yao, Chun-Han and Boss, Mark and Letts, Adam and Pankratz, David and Tochilkin, Dmitrii and Laforte, Christian and Rombach, Robin and Jampani, Varun},
+ title={{SV3D}: Novel Multi-view Synthesis and {3D} Generation from a Single Image using Latent Video Diffusion},
+ booktitle={European Conference on Computer Vision (ECCV)},
+ year={2024},
+}
+```
diff --git a/libs/sv3d/convert/convert_sv3d.py b/libs/sv3d/convert/convert_sv3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d4a0756d4c6c4fc312cd2aad198709cfd47c5bc
--- /dev/null
+++ b/libs/sv3d/convert/convert_sv3d.py
@@ -0,0 +1,73 @@
+import os
+import argparse
+
+from huggingface_hub import hf_hub_download
+import safetensors.torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from diffusers import (
+ AutoencoderKL,
+ # AutoencoderKLTemporalDecoder,
+ EulerDiscreteScheduler,
+)
+
+from convert.convert_svd_to_diffusers import (
+ convert_ldm_unet_checkpoint,
+ # convert_ldm_vae_checkpoint,
+ create_unet_diffusers_config,
+)
+from diffusers_sv3d import SV3DUNetSpatioTemporalConditionModel, StableVideo3DDiffusionPipeline
+
+SVD_V1_CKPT = "stabilityai/stable-video-diffusion-img2vid-xt"
+SD_V15_CKPT = "chenguolin/stable-diffusion-v1-5"
+HF_HOME = "~/.cache/huggingface"
+HF_TOKEN = ""
+HF_USERNAME = ""
+
+# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+os.environ["HF_HOME"] = HF_HOME
+os.environ["HF_USERNAME"] = HF_USERNAME
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--original_ckpt_path", default=os.path.expanduser(f"{HF_HOME}/hub/models--stabilityai--sv3d/snapshots/31213729b4314a44b574ce7cc2d0c28356f097ed/sv3d_p.safetensors"), type=str, help="Path to the checkpoint to convert.")
+ parser.add_argument("--hf_token", default=HF_TOKEN, type=str, help="your HuggingFace token")
+ parser.add_argument("--config_path", default="convert/sv3d_p.yaml", type=str, help="Config filepath.")
+ parser.add_argument("--repo_name", default="sv3d-diffusers", type=str)
+ parser.add_argument("--push_to_hub", action="store_true")
+ args = parser.parse_args()
+
+ if not os.path.exists(args.original_ckpt_path):
+ token = HF_TOKEN # open(os.path.expanduser("~/.cache/huggingface/token"), "r").read()
+ hf_hub_download("stabilityai/sv3d", filename="sv3d_p.safetensors", token=token)
+ original_ckpt = safetensors.torch.load_file(args.original_ckpt_path, device="cpu")
+
+ from omegaconf import OmegaConf
+ config = OmegaConf.load(args.config_path)
+
+ unet_config = create_unet_diffusers_config(config, image_size=576)
+
+ ori_config = unet_config.copy()
+ unet_config.pop("attention_head_dim")
+ unet_config.pop("use_linear_projection")
+ unet_config.pop("class_embed_type")
+ unet_config.pop("addition_embed_type")
+ unet = SV3DUNetSpatioTemporalConditionModel(**unet_config)
+ unet_state_dict = convert_ldm_unet_checkpoint(original_ckpt, ori_config)
+ unet.load_state_dict(unet_state_dict, strict=True)
+
+ # unet.save_pretrained("out/sv3d-diffusers", push_to_hub=True)
+
+ vae = AutoencoderKL.from_pretrained(SD_V15_CKPT, subfolder="vae")
+ scheduler = EulerDiscreteScheduler.from_pretrained(SVD_V1_CKPT, subfolder="scheduler")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(SVD_V1_CKPT, subfolder="image_encoder")
+ feature_extractor = CLIPImageProcessor.from_pretrained(SVD_V1_CKPT, subfolder="feature_extractor")
+
+ pipeline = StableVideo3DDiffusionPipeline(
+ image_encoder=image_encoder, feature_extractor=feature_extractor,
+ unet=unet, vae=vae,
+ scheduler=scheduler,
+ )
+
+ if args.push_to_hub:
+ pipeline.push_to_hub(args.repo_name)
diff --git a/libs/sv3d/convert/convert_svd_to_diffusers.py b/libs/sv3d/convert/convert_svd_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a932a7f4888799624a09aef63b86f11f1eaed7
--- /dev/null
+++ b/libs/sv3d/convert/convert_svd_to_diffusers.py
@@ -0,0 +1,728 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/scripts/convert_svd_to_diffusers.py
+from diffusers.utils import is_accelerate_available, logging
+
+
+if is_accelerate_available():
+ pass
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
+ """
+    Creates a diffusers config based on the config of the LDM model.
+ """
+ if controlnet:
+ unet_params = original_config.model.params.control_stage_config.params
+ else:
+ if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
+ unet_params = original_config.model.params.unet_config.params
+ else:
+ unet_params = original_config.model.params.network_config.params
+
+ vae_params = original_config.model.params.first_stage_config.params.decoder_config.params
+
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = (
+ "CrossAttnDownBlockSpatioTemporal"
+ if resolution in unet_params.attention_resolutions
+ else "DownBlockSpatioTemporal"
+ )
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = (
+ "CrossAttnUpBlockSpatioTemporal"
+ if resolution in unet_params.attention_resolutions
+ else "UpBlockSpatioTemporal"
+ )
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ if unet_params.transformer_depth is not None:
+ transformer_layers_per_block = (
+ unet_params.transformer_depth
+ if isinstance(unet_params.transformer_depth, int)
+ else list(unet_params.transformer_depth)
+ )
+ else:
+ transformer_layers_per_block = 1
+
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+ head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
+
+ class_embed_type = None
+ addition_embed_type = None
+ addition_time_embed_dim = None
+ projection_class_embeddings_input_dim = None
+ context_dim = None
+
+ if unet_params.context_dim is not None:
+ context_dim = (
+ unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
+ )
+
+ if "num_classes" in unet_params:
+ if unet_params.num_classes == "sequential":
+ addition_time_embed_dim = 256
+ assert "adm_in_channels" in unet_params
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": unet_params.in_channels,
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": unet_params.num_res_blocks,
+ "cross_attention_dim": context_dim,
+ "attention_head_dim": tuple(head_dim),
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "addition_embed_type": addition_embed_type,
+ "addition_time_embed_dim": addition_time_embed_dim,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "transformer_layers_per_block": transformer_layers_per_block,
+ }
+
+ if "disable_self_attentions" in unet_params:
+ config["only_cross_attention"] = unet_params.disable_self_attentions
+
+ if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
+ config["num_class_embeds"] = unet_params.num_classes
+
+ if controlnet:
+ config["conditioning_channels"] = unet_params.hint_channels
+ else:
+ config["out_channels"] = unet_params.out_channels
+ config["up_block_types"] = tuple(up_block_types)
+
+ return config
+
+
+def assign_to_checkpoint(
+ paths,
+ checkpoint,
+ old_checkpoint,
+ attention_paths_to_split=None,
+ additional_replacements=None,
+ config=None,
+ mid_block_suffix="",
+):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
+ attention layers, and takes into account additional replacements that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ if mid_block_suffix is not None:
+ mid_block_suffix = f".{mid_block_suffix}"
+ else:
+ mid_block_suffix = ""
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+ shape = old_checkpoint[path["old"]].shape
+ if is_attn_weight and len(shape) == 3:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ elif is_attn_weight and len(shape) == 4:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+ new_item = new_item.replace("time_stack", "temporal_transformer_blocks")
+
+ new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
+ new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
+ new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
+ new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
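+# Example (illustrative sketch):
+#   >>> shave_segments("output_blocks.2.1.norm.weight", n_shave_prefix_segments=3)
+#   'norm.weight'
+#   >>> shave_segments("output_blocks.2.1.norm.weight", n_shave_prefix_segments=-1)
+#   'output_blocks.2.1.norm'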
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = new_item.replace("time_stack.", "")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
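+# Example (illustrative sketch):
+#   >>> renew_resnet_paths(["input_blocks.1.0.in_layers.2.weight"])
+#   [{'old': 'input_blocks.1.0.in_layers.2.weight', 'new': 'input_blocks.1.0.conv1.weight'}]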
+
+def convert_ldm_unet_checkpoint(
+ checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ if skip_extract_state_dict:
+ unet_state_dict = checkpoint
+ else:
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ unet_key = "model.diffusion_model."
+
+        # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ logger.warning(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ logger.warning(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ # if config["addition_embed_type"] == "text_time":
+ new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ spatial_resnets = [
+ key
+ for key in input_blocks[i]
+ if f"input_blocks.{i}.0" in key
+ and (
+ f"input_blocks.{i}.0.op" not in key
+ and f"input_blocks.{i}.0.time_stack" not in key
+ and f"input_blocks.{i}.0.time_mixer" not in key
+ )
+ ]
+ temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
+ # import ipdb; ipdb.set_trace()
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(spatial_resnets)
+ meta_path = {
+ "old": f"input_blocks.{i}.0",
+ "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ paths = renew_resnet_paths(temporal_resnets)
+ meta_path = {
+ "old": f"input_blocks.{i}.0",
+ "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ # TODO resnet time_mixer.mix_factor
+ if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
+ new_checkpoint[
+ f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+ ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ # import ipdb; ipdb.set_trace()
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
+ resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
+ # import ipdb; ipdb.set_trace()
+ assign_to_checkpoint(
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
+ )
+
+ resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
+ resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
+ assign_to_checkpoint(
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
+ )
+
+ resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
+ resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
+ assign_to_checkpoint(
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
+ )
+
+ resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
+ resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
+ assign_to_checkpoint(
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
+ )
+
+ new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
+ "middle_block.0.time_mixer.mix_factor"
+ ]
+ new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
+ "middle_block.2.time_mixer.mix_factor"
+ ]
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ spatial_resnets = [
+ key
+ for key in output_blocks[i]
+ if f"output_blocks.{i}.0" in key
+ and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
+ ]
+ # import ipdb; ipdb.set_trace()
+
+ temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]
+
+ paths = renew_resnet_paths(spatial_resnets)
+ meta_path = {
+ "old": f"output_blocks.{i}.0",
+ "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ paths = renew_resnet_paths(temporal_resnets)
+ meta_path = {
+ "old": f"output_blocks.{i}.0",
+ "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
+ new_checkpoint[
+ f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+ ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ # import ipdb; ipdb.set_trace()
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ spatial_layers = [
+ layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
+ ]
+ resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
+ # import ipdb; ipdb.set_trace()
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(
+ ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
+ )
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ temporal_layers = [
+                layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in layer
+ ]
+ resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
+ # import ipdb; ipdb.set_trace()
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(
+ ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
+ )
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
+ f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
+ ]
+
+ return new_checkpoint
+
+
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # Temporal resnet
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = new_item.replace("time_stack.", "temporal_res_block.")
+
+ # Spatial resnet
+ new_item = new_item.replace("conv1", "spatial_res_block.conv1")
+ new_item = new_item.replace("norm1", "spatial_res_block.norm1")
+
+ new_item = new_item.replace("conv2", "spatial_res_block.conv2")
+ new_item = new_item.replace("norm2", "spatial_res_block.norm2")
+
+ new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")
+
+ new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("q.weight", "to_q.weight")
+ new_item = new_item.replace("q.bias", "to_q.bias")
+
+ new_item = new_item.replace("k.weight", "to_k.weight")
+ new_item = new_item.replace("k.bias", "to_k.bias")
+
+ new_item = new_item.replace("v.weight", "to_v.weight")
+ new_item = new_item.replace("v.bias", "to_v.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ # extract state dict for VAE
+ vae_state_dict = {}
+ keys = list(checkpoint.keys())
+ vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+ new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
+ new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]
+
+ # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+ return new_checkpoint
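+
+
+# A minimal usage sketch (kept as a comment; the checkpoint path and config loading below are
+# assumptions for illustration, not part of this module):
+#
+# from safetensors.torch import load_file
+# from diffusers import AutoencoderKLTemporalDecoder
+#
+# vae_config = AutoencoderKLTemporalDecoder.load_config("path/to/vae")      # hypothetical config dir
+# vae = AutoencoderKLTemporalDecoder.from_config(vae_config)
+# state_dict = load_file("checkpoints/sv3d_p.safetensors")                  # path as referenced in sv3d_p.yaml
+# vae.load_state_dict(convert_ldm_vae_checkpoint(state_dict, vae_config))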
diff --git a/libs/sv3d/convert/sv3d_p.yaml b/libs/sv3d/convert/sv3d_p.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..04878915e67348256eb0322bf7ad0fb3be2c9295
--- /dev/null
+++ b/libs/sv3d/convert/sv3d_p.yaml
@@ -0,0 +1,134 @@
+# Copied from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/configs/sv3d_p.yaml
+
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.18215
+ disable_first_stage_autocast: True
+ ckpt_path: checkpoints/sv3d_p.safetensors
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
+ params:
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+ network_config:
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
+ params:
+ adm_in_channels: 1280
+ num_classes: sequential
+ use_checkpoint: True
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_head_channels: 64
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ spatial_transformer_attn_type: softmax-xformers
+ extra_ff_mix_layer: True
+ use_spatial_context: True
+ merge_strategy: learned_with_images
+ video_kernel_size: [3, 1, 1]
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - input_key: cond_frames_without_noise
+ is_trainable: False
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+ params:
+ n_cond_frames: 1
+ n_copies: 1
+ open_clip_embedding_config:
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+ params:
+ freeze: True
+
+ - input_key: cond_frames
+ is_trainable: False
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+ params:
+ disable_encoder_autocast: True
+ n_cond_frames: 1
+ n_copies: 1
+ is_ae: True
+ encoder_config:
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ - input_key: cond_aug
+ is_trainable: False
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256
+
+ - input_key: polars_rad
+ is_trainable: False
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 512
+
+ - input_key: azimuths_rad
+ is_trainable: False
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 512
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencodingEngine
+ params:
+ loss_config:
+ target: torch.nn.Identity
+ regularizer_config:
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+ encoder_config:
+ target: torch.nn.Identity
+ decoder_config:
+ target: sgm.modules.diffusionmodules.model.Decoder
+ params:
+ attn_type: vanilla-xformers
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 4, 4 ]
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+ params:
+ sigma_max: 700.0
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
+ params:
+ max_scale: 2.5
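+
+# Note: the three ConcatTimestepEmbedderND conditioners above (cond_aug: 256, polars_rad: 512,
+# azimuths_rad: 512) concatenate to the UNet's adm_in_channels of 1280; the diffusers port in this
+# repo reproduces the same split with add_time_proj / add_angle_proj.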
diff --git a/libs/sv3d/diffusers_sv3d/__init__.py b/libs/sv3d/diffusers_sv3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33919ce3d5956fc45c88b95b5792c82194545fe1
--- /dev/null
+++ b/libs/sv3d/diffusers_sv3d/__init__.py
@@ -0,0 +1,2 @@
+from .models import SV3DUNetSpatioTemporalConditionModel
+from .pipelines import StableVideo3DDiffusionPipeline
diff --git a/libs/sv3d/diffusers_sv3d/models/__init__.py b/libs/sv3d/diffusers_sv3d/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..574adc5c0fdb11ce9520d67683763fa1ededaa56
--- /dev/null
+++ b/libs/sv3d/diffusers_sv3d/models/__init__.py
@@ -0,0 +1 @@
+from .unets import SV3DUNetSpatioTemporalConditionModel
diff --git a/libs/sv3d/diffusers_sv3d/models/unets/__init__.py b/libs/sv3d/diffusers_sv3d/models/unets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fed2bf59f3e3ef141c5480bf4e5f68be4dfe0e6
--- /dev/null
+++ b/libs/sv3d/diffusers_sv3d/models/unets/__init__.py
@@ -0,0 +1 @@
+from .unet_spatio_temporal_condition import SV3DUNetSpatioTemporalConditionModel
diff --git a/libs/sv3d/diffusers_sv3d/models/unets/unet_spatio_temporal_condition.py b/libs/sv3d/diffusers_sv3d/models/unets/unet_spatio_temporal_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a7a3c310e7b0289025e3cb26c876f9226ee526
--- /dev/null
+++ b/libs/sv3d/diffusers_sv3d/models/unets/unet_spatio_temporal_condition.py
@@ -0,0 +1,483 @@
+from typing import *
+
+from diffusers.models.unets.unet_spatio_temporal_condition import *
+
+
+# Copied from diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionModel
+class SV3DUNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep, and
+ returns a sample-shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ addition_time_embed_dim: (`int`, defaults to 256):
+ Dimension used to encode the additional time ids.
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
+ The dimension of the projection of encoded `added_time_ids`.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
+ [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
+ The number of attention heads.
+ num_frames (`int`, *optional*, defaults to 25): The default number of video frames; the pipeline falls back to this value when `num_frames` is not passed.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
+ num_frames: int = 25,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=1,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+ self.add_angle_proj = Timesteps(2*addition_time_embed_dim, True, downscale_freq_shift=0) # encode camera angles
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
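+ # For SV3D the added conditioning is cond_aug (addition_time_embed_dim = 256) plus two camera
+ # angles (2 * 256 = 512 each), so `projection_class_embeddings_input_dim` is expected to be
+ # 256 + 512 + 512 = 1280 in the converted config (the 768 default matches the original SVD layout
+ # of three 256-dim ids).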
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-5,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockSpatioTemporal(
+ block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=1e-5,
+ resolution_idx=i,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+ self.conv_act = nn.SiLU()
+
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0],
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Enables [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) for the feed-forward layers.
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: Union[torch.Tensor, List[torch.Tensor]],
+ return_dict: bool = True,
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+ r"""
+ The [`UNetSpatioTemporalConditionModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.Tensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] instead
+ of a plain tuple.
+ Returns:
+ [`~models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
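+ # Two conditioning layouts are supported: a single tensor (the original SVD micro-conditioning,
+ # embedded with add_time_proj only) or a list [cond_aug, polars, azimuths] (the SV3D path, where
+ # the two camera angles go through the wider add_angle_proj before concatenation).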
+ if isinstance(added_time_ids, torch.Tensor):
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ elif isinstance(added_time_ids, list):
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+
+ cond_aug, polars, azimuths = added_time_ids
+ cond_aug_emb = self.add_time_proj(cond_aug.flatten())
+ polars_emb = self.add_angle_proj(polars.flatten())
+ azimuths_emb = self.add_angle_proj(azimuths.flatten())
+ time_embeds = torch.cat([cond_aug_emb, polars_emb, azimuths_emb],dim=1)
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+ else:
+ raise ValueError
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 7. Reshape back to original shape
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+ if not return_dict:
+ return (sample,)
+
+ return UNetSpatioTemporalConditionOutput(sample=sample)
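+
+
+# Shape notes (a sketch of the call as used by StableVideo3DDiffusionPipeline in this repo, not a
+# strict API contract): `sample` is (batch, num_frames, 8, h, w), i.e. 4 noisy latent channels
+# concatenated with 4 channels of the VAE-encoded conditioning image; `encoder_hidden_states` is the
+# CLIP image embedding of shape (batch, 1, 1024); `added_time_ids` is the list
+# [cond_aug, polars_rad, azimuths_rad] with each tensor of shape (batch, num_frames), duplicated
+# along the batch dimension when classifier-free guidance is enabled.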
diff --git a/libs/sv3d/diffusers_sv3d/pipelines/__init__.py b/libs/sv3d/diffusers_sv3d/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea7e2eaba4677684886863cbfe5165f17efa911c
--- /dev/null
+++ b/libs/sv3d/diffusers_sv3d/pipelines/__init__.py
@@ -0,0 +1 @@
+from .stable_video_diffusion import StableVideo3DDiffusionPipeline
diff --git a/libs/sv3d/diffusers_sv3d/pipelines/stable_video_diffusion/__init__.py b/libs/sv3d/diffusers_sv3d/pipelines/stable_video_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc498bc2a865a0a9425adafb4ceae2c0708239bd
--- /dev/null
+++ b/libs/sv3d/diffusers_sv3d/pipelines/stable_video_diffusion/__init__.py
@@ -0,0 +1 @@
+from .pipeline_stable_video_3d_diffusion import StableVideo3DDiffusionPipeline
diff --git a/libs/sv3d/diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion.py b/libs/sv3d/diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..d78fd3ffe960402d831d07881f01fdbf90fd49c0
--- /dev/null
+++ b/libs/sv3d/diffusers_sv3d/pipelines/stable_video_diffusion/pipeline_stable_video_3d_diffusion.py
@@ -0,0 +1,222 @@
+from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _append_dims
+from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import *
+
+
+# Copied from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.StableVideoDiffusionPipeline
+class StableVideo3DDiffusionPipeline(StableVideoDiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKLTemporalDecoder,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNetSpatioTemporalConditionModel,
+ scheduler: EulerDiscreteScheduler,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__(
+ vae,
+ image_encoder,
+ unet,
+ scheduler,
+ feature_extractor,
+ )
+
+ def _get_add_time_ids(
+ self,
+ noise_aug_strength: float,
+ polars_rad: List[float],
+ azimuths_rad: List[float],
+ dtype: torch.dtype,
+ batch_size: int,
+ num_videos_per_prompt: int,
+ do_classifier_free_guidance: bool,
+ ):
+ cond_aug = torch.tensor([noise_aug_strength]*len(polars_rad), dtype=dtype).repeat(batch_size * num_videos_per_prompt, 1)
+ polars_rad = torch.tensor(polars_rad, dtype=dtype).repeat(batch_size * num_videos_per_prompt, 1)
+ azimuths_rad = torch.tensor(azimuths_rad, dtype=dtype).repeat(batch_size * num_videos_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ cond_aug = torch.cat([cond_aug, cond_aug])
+ polars_rad = torch.cat([polars_rad, polars_rad])
+ azimuths_rad = torch.cat([azimuths_rad, azimuths_rad])
+
+ add_time_ids = [cond_aug, polars_rad, azimuths_rad]
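+ # Each tensor is (batch_size * num_videos_per_prompt, num_frames); under classifier-free guidance
+ # they are duplicated along the batch dimension to line up with torch.cat([latents] * 2).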
+
+ return add_time_ids
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
+
+ polars_rad: List[float],
+ azimuths_rad: List[float],
+ triangle_cfg_scaling: bool = True,
+
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 25,
+ sigmas: Optional[List[float]] = None,
+ min_guidance_scale: float = 1.0,
+ max_guidance_scale: float = 2.5,
+ noise_aug_strength: float = 1e-5,
+ decode_chunk_size: Optional[int] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ return_dict: bool = True,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width)
+
+ # 2. Define call parameters
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ self._guidance_scale = max_guidance_scale
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+
+ # 4. Encode input image using VAE
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device)
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
+ image = image + noise_aug_strength * noise
+
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float32)
+
+ image_latents = self._encode_vae_image(
+ image,
+ device=device,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+ image_latents = image_latents.to(image_embeddings.dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # Repeat the image latents for each frame so we can concatenate them with the noise
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+ # 5. Get Added Time IDs
+ added_time_ids = self._get_add_time_ids(
+ noise_aug_strength,
+ polars_rad,
+ azimuths_rad,
+ image_embeddings.dtype,
+ batch_size,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+ added_time_ids = [a.to(device) for a in added_time_ids] # (cond_aug, polars_rad, azimuths_rad)
+
+ # 6. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Prepare guidance scale
+ if triangle_cfg_scaling:
+ # Triangle CFG scaling; the last view is input condition
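+ # The scale ramps linearly from min_guidance_scale near the conditioning view up to
+ # max_guidance_scale on the far side of the orbit and back down, so frames closest to the input
+ # image are guided the least.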
+ guidance_scale = torch.cat([
+ torch.linspace(min_guidance_scale, max_guidance_scale, num_frames//2 + 1)[1:].unsqueeze(0),
+ torch.linspace(max_guidance_scale, min_guidance_scale, num_frames - num_frames//2 + 1)[1:].unsqueeze(0)
+ ], dim=-1)
+ else:
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
+ guidance_scale = guidance_scale.to(device, latents.dtype)
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
+
+ self._guidance_scale = guidance_scale
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # Concatenate image_latents over channels dimension
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=image_embeddings,
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
+ frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
+ else:
+ frames = latents
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return frames
+
+ return StableVideoDiffusionPipelineOutput(frames=frames)
diff --git a/libs/sv3d/infer.py b/libs/sv3d/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9f6284893a5279d80fe0374e80aecffcc27a3d
--- /dev/null
+++ b/libs/sv3d/infer.py
@@ -0,0 +1,130 @@
+import os
+import argparse
+import rembg
+import numpy as np
+import math
+import torch
+import json
+
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from diffusers import AutoencoderKL, EulerDiscreteScheduler
+from diffusers.utils import export_to_gif
+from diffusers_sv3d import SV3DUNetSpatioTemporalConditionModel, StableVideo3DDiffusionPipeline
+from kiui.cam import orbit_camera
+
+SV3D_DIFFUSERS = "chenguolin/sv3d-diffusers"
+
+# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+os.environ["HF_HOME"] = "~/.cache/huggingface"
+
+def construct_camera(azimuths_rad, elevation_rad, output_dir, res=576, radius=2, fov=33.8):
+
+ transforms = {}
+ transforms["camera_angle_x"] = math.radians(fov)
+ transforms["frames"] = []
+
+ for i in range(len(azimuths_rad)):
+ frame = {}
+ frame['file_path'] = f"data/{i:03d}"
+ frame['transform_matrix'] = orbit_camera(elevation_rad[i], azimuths_rad[i], radius, is_degree=False).tolist()
+ transforms['frames'].append(frame)
+
+ with open(f"{output_dir}/../transforms_train.json", "w") as f:
+ json.dump(transforms, f, indent=4)
+ with open(f"{output_dir}/../transforms_val.json", "w") as f:
+ json.dump(transforms, f, indent=4)
+ with open(f"{output_dir}/../transforms_test.json", "w") as f:
+ json.dump(transforms, f, indent=4)
+
+def recenter(image, h_begin=100, w_begin=220, res=256):
+ image = np.array(image)
+ h_image, w_image = image.shape[:2]
+ new_image = np.zeros((res, res, 4), dtype=np.uint8)
+ h_begin_new = -min(0, h_begin)
+ w_begin_new = -min(0, w_begin)
+ if h_begin > 0 and w_begin > 0:
+ new_image = image[h_begin:h_begin+res, w_begin:w_begin+res]
+ else:
+ new_image[h_begin_new:h_begin_new+h_image, w_begin_new:w_begin_new+w_image] = image
+ return Image.fromarray(new_image)
+
+def main():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-dir", default="../../data", type=str, help="Base dir")
+ parser.add_argument("--output-dir", default="../../data", type=str, help="Output filepath")
+ parser.add_argument("--data-name", default="chair", type=str, help="Data Name")
+ parser.add_argument("--elevation", default=0, type=float, help="Camera elevation of the input image")
+ parser.add_argument("--half-precision", action="store_true", help="Use fp16 half precision")
+ parser.add_argument("--seed", default=-1, type=int, help="Random seed")
+ args = parser.parse_args()
+
+ image_path = f'{args.base_dir}/{args.data_name}/{args.data_name}.png'
+ output_dir = f'{args.output_dir}/{args.data_name}/data'
+ os.makedirs(output_dir, exist_ok=True)
+
+ num_frames, sv3d_res = 20, 576
+ elevations_deg = [args.elevation] * num_frames
+ elevations_rad = [np.deg2rad(e) for e in elevations_deg]
+ polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
+ azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
+ azimuths_rad = [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
+ azimuths_rad[:-1].sort()
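+ # Cameras orbit the object at fixed elevation: polar = 90 deg - elevation (SV3D's colatitude
+ # convention) and azimuths are spaced evenly over 360 deg with the final frame at azimuth 0, i.e.
+ # the input view. Note that `azimuths_rad[:-1].sort()` above sorts a slice copy and is a no-op.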
+
+ # print(f"Elevation: {elevations_rad}")
+ print(f"Azimuth: {np.rad2deg(azimuths_rad)}")
+ # construct_camera(azimuths_rad, elevations_rad, output_dir=output_dir)
+
+ bg_remover = rembg.new_session()
+ unet = SV3DUNetSpatioTemporalConditionModel.from_pretrained(SV3D_DIFFUSERS, subfolder="unet")
+ vae = AutoencoderKL.from_pretrained(SV3D_DIFFUSERS, subfolder="vae")
+ scheduler = EulerDiscreteScheduler.from_pretrained(SV3D_DIFFUSERS, subfolder="scheduler")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(SV3D_DIFFUSERS, subfolder="image_encoder")
+ feature_extractor = CLIPImageProcessor.from_pretrained(SV3D_DIFFUSERS, subfolder="feature_extractor")
+
+ pipeline = StableVideo3DDiffusionPipeline(
+ image_encoder=image_encoder, feature_extractor=feature_extractor,
+ unet=unet, vae=vae,
+ scheduler=scheduler,
+ )
+ pipeline = pipeline.to("cuda")
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=torch.float16 if args.half_precision else torch.float32, enabled=True):
+
+ h_begin, w_begin, res = 180, 190, 280
+ image = Image.open(image_path)
+ image = recenter(image, h_begin, w_begin, res)
+ image = rembg.remove(image, session=bg_remover) # [H, W, 4]
+ image.save(f"{output_dir}/../{args.data_name}_alpha.png")
+ if len(image.split()) == 4: # RGBA
+ input_image = Image.new("RGB", image.size, (255, 255, 255)) # pure white bg
+ input_image.paste(image, mask=image.split()[3]) # band index 3 is the alpha channel
+ else:
+ input_image = image
+
+ video_frames = pipeline(
+ input_image.resize((sv3d_res, sv3d_res)),
+ height=sv3d_res,
+ width=sv3d_res,
+ num_frames=num_frames,
+ decode_chunk_size=8, # smaller to save memory
+ polars_rad=polars_rad,
+ azimuths_rad=azimuths_rad,
+ generator=torch.manual_seed(args.seed) if args.seed >= 0 else None,
+ ).frames[0]
+
+ os.makedirs(output_dir, exist_ok=True)
+ export_to_gif(video_frames, f"{output_dir}/animation.gif", fps=7)
+ for i, frame in enumerate(video_frames):
+ # frame = frame.resize((res, res))
+ frame.save(f"{output_dir}/{i:03d}.png")
+ video_frames[19].save(f"../LGM/workspace_test/{args.data_name}_0.png")
+ video_frames[4].save(f"../LGM/workspace_test/{args.data_name}_1.png")
+ video_frames[9].save(f"../LGM/workspace_test/{args.data_name}_2.png")
+ video_frames[14].save(f"../LGM/workspace_test/{args.data_name}_3.png")
+
+
+if __name__ == "__main__":
+ main()
+
\ No newline at end of file
diff --git a/libs/sv3d/requirements.txt b/libs/sv3d/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f37afc2e26c2ec57bdfb06d49654df85569b653f
--- /dev/null
+++ b/libs/sv3d/requirements.txt
@@ -0,0 +1,7 @@
+Pillow
+numpy
+transformers
+diffusers
+omegaconf
+rembg
+onnxruntime
\ No newline at end of file
diff --git a/libs/vggt/dependency/__init__.py b/libs/vggt/dependency/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eab06de2c911398e0339782e327bffd4cc9f91c
--- /dev/null
+++ b/libs/vggt/dependency/__init__.py
@@ -0,0 +1,3 @@
+from .track_modules.track_refine import refine_track
+from .track_modules.blocks import BasicEncoder, ShallowEncoder
+from .track_modules.base_track_predictor import BaseTrackerPredictor
diff --git a/libs/vggt/dependency/distortion.py b/libs/vggt/dependency/distortion.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3510230265dbd088844076e9d5763a35f7d712b
--- /dev/null
+++ b/libs/vggt/dependency/distortion.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+from typing import Union
+
+ArrayLike = Union[np.ndarray, torch.Tensor]
+
+
+def _is_numpy(x: ArrayLike) -> bool:
+ return isinstance(x, np.ndarray)
+
+
+def _is_torch(x: ArrayLike) -> bool:
+ return isinstance(x, torch.Tensor)
+
+
+def _ensure_torch(x: ArrayLike) -> torch.Tensor:
+ """Convert input to torch tensor if it's not already one."""
+ if _is_numpy(x):
+ return torch.from_numpy(x)
+ elif _is_torch(x):
+ return x
+ else:
+ return torch.tensor(x)
+
+
+def single_undistortion(params, tracks_normalized):
+ """
+ Apply undistortion to the normalized tracks using the given distortion parameters once.
+
+ Args:
+ params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
+ tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
+
+ Returns:
+ torch.Tensor: Undistorted normalized tracks tensor.
+ """
+ params = _ensure_torch(params)
+ tracks_normalized = _ensure_torch(tracks_normalized)
+
+ u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
+ u_undist, v_undist = apply_distortion(params, u, v)
+ return torch.stack([u_undist, v_undist], dim=-1)
+
+
+def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6):
+ """
+ Iteratively undistort the normalized tracks using the given distortion parameters.
+
+ Args:
+ params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
+ tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
+ max_iterations (int): Maximum number of iterations for the undistortion process.
+ max_step_norm (float): Maximum step norm for convergence.
+ rel_step_size (float): Relative step size for numerical differentiation.
+
+ Returns:
+ torch.Tensor: Undistorted normalized tracks tensor.
+ """
+ params = _ensure_torch(params)
+ tracks_normalized = _ensure_torch(tracks_normalized)
+
+ B, N, _ = tracks_normalized.shape
+ u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
+ original_u, original_v = u.clone(), v.clone()
+
+ eps = torch.finfo(u.dtype).eps
+ for idx in range(max_iterations):
+ u_undist, v_undist = apply_distortion(params, u, v)
+ dx = original_u - u_undist
+ dy = original_v - v_undist
+
+ step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps)
+ step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps)
+
+ J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u)
+ J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v)
+ J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u)
+ J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v)
+
+ J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2)
+
+ delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1))
+
+ u += delta[..., 0]
+ v += delta[..., 1]
+
+ if torch.max((delta**2).sum(dim=-1)) < max_step_norm:
+ break
+
+ return torch.stack([u, v], dim=-1)
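+
+# The iteration above inverts apply_distortion numerically: for every point it estimates a 2x2
+# Jacobian of the distortion mapping with central finite differences, solves J @ delta = residual
+# for an update to (u, v), and stops once the largest squared step norm falls below max_step_norm.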
+
+
+def apply_distortion(extra_params, u, v):
+ """
+ Applies radial or OpenCV distortion to the given 2D points.
+
+ Args:
+ extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
+ u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks.
+ v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks.
+
+ Returns:
+ points2D (torch.Tensor): Distorted 2D points of shape BxNx2.
+ """
+ extra_params = _ensure_torch(extra_params)
+ u = _ensure_torch(u)
+ v = _ensure_torch(v)
+
+ num_params = extra_params.shape[1]
+
+ if num_params == 1:
+ # Simple radial distortion
+ k = extra_params[:, 0]
+ u2 = u * u
+ v2 = v * v
+ r2 = u2 + v2
+ radial = k[:, None] * r2
+ du = u * radial
+ dv = v * radial
+
+ elif num_params == 2:
+ # RadialCameraModel distortion
+ k1, k2 = extra_params[:, 0], extra_params[:, 1]
+ u2 = u * u
+ v2 = v * v
+ r2 = u2 + v2
+ radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
+ du = u * radial
+ dv = v * radial
+
+ elif num_params == 4:
+ # OpenCVCameraModel distortion
+ k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3])
+ u2 = u * u
+ v2 = v * v
+ uv = u * v
+ r2 = u2 + v2
+ radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
+ du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2)
+ dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2)
+ else:
+ raise ValueError("Unsupported number of distortion parameters")
+
+ u = u.clone() + du
+ v = v.clone() + dv
+
+ return u, v
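+
+# Parameter layouts follow COLMAP's camera models: 1 parameter -> simple radial (k), 2 -> radial
+# (k1, k2), 4 -> OpenCV-style radial + tangential (k1, k2, p1, p2); all are applied to normalized
+# image coordinates.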
+
+
+if __name__ == "__main__":
+ import random
+ import pycolmap
+
+ max_diff = 0
+ for i in range(1000):
+ # Define distortion parameters (assuming 1 parameter for simplicity)
+ B = random.randint(1, 500)
+ track_num = random.randint(100, 1000)
+ params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters
+ tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32) # Batch size 1, 5 points
+
+ # Undistort the tracks
+ undistorted_tracks = iterative_undistortion(params, tracks_normalized)
+
+ for b in range(B):
+ pycolmap_intri = np.array([1, 0, 0, params[b].item()])
+ pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0)
+
+ undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy())
+ diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median()
+ max_diff = max(max_diff, diff)
+ print(f"diff: {diff}, max_diff: {max_diff}")
+
+ import pdb
+
+ pdb.set_trace()
diff --git a/libs/vggt/dependency/np_to_pycolmap.py b/libs/vggt/dependency/np_to_pycolmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..61ea578692d2b5a5cd5b6fd15836373a94351489
--- /dev/null
+++ b/libs/vggt/dependency/np_to_pycolmap.py
@@ -0,0 +1,320 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import pycolmap
+from .projection import project_3D_points_np
+
+
+def batch_np_matrix_to_pycolmap(
+ points3d,
+ extrinsics,
+ intrinsics,
+ tracks,
+ image_size,
+ masks=None,
+ max_reproj_error=None,
+ max_points3D_val=3000,
+ shared_camera=False,
+ camera_type="SIMPLE_PINHOLE",
+ extra_params=None,
+ min_inlier_per_frame=64,
+ points_rgb=None,
+):
+ """
+ Convert Batched NumPy Arrays to PyCOLMAP
+
+ Check https://github.com/colmap/pycolmap for more details about its format
+
+ NOTE that colmap expects images/cameras/points3D to be 1-indexed
+ so there is a +1 offset between colmap index and batch index
+
+
+ NOTE: different from VGGSfM, this function:
+ 1. Use np instead of torch
+ 2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP)
+ """
+ # points3d: Px3
+ # extrinsics: Nx3x4
+ # intrinsics: Nx3x3
+ # tracks: NxPx2
+ # masks: NxP
+ # image_size: 2, assume all the frames have been padded to the same size
+ # where N is the number of frames and P is the number of tracks
+
+ N, P, _ = tracks.shape
+ assert len(extrinsics) == N
+ assert len(intrinsics) == N
+ assert len(points3d) == P
+ assert image_size.shape[0] == 2
+
+ reproj_mask = None
+
+ if max_reproj_error is not None:
+ projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics)
+ projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1)
+ projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6
+ reproj_mask = projected_diff < max_reproj_error
+
+ if masks is not None and reproj_mask is not None:
+ masks = np.logical_and(masks, reproj_mask)
+ elif masks is not None:
+ masks = masks
+ else:
+ masks = reproj_mask
+
+ assert masks is not None
+
+ if masks.sum(1).min() < min_inlier_per_frame:
+ print(f"Not enough inliers per frame, skip BA.")
+ return None, None
+
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
+ reconstruction = pycolmap.Reconstruction()
+
+ inlier_num = masks.sum(0)
+ valid_mask = inlier_num >= 2 # a track is invalid if without two inliers
+ valid_idx = np.nonzero(valid_mask)[0]
+
+ # Only add 3D points that have sufficient 2D points
+ for vidx in valid_idx:
+ # Use RGB colors if provided, otherwise use zeros
+ rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3)
+ reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb)
+
+ num_points3D = len(valid_idx)
+ camera = None
+ # frame idx
+ for fidx in range(N):
+ # set camera
+ if camera is None or (not shared_camera):
+ pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params)
+
+ camera = pycolmap.Camera(
+ model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1
+ )
+
+ # add camera
+ reconstruction.add_camera(camera)
+
+ # set image
+ cam_from_world = pycolmap.Rigid3d(
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
+ ) # Rot and Trans
+
+ image = pycolmap.Image(
+ id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world
+ )
+
+ points2D_list = []
+
+ point2D_idx = 0
+
+ # NOTE point3D_id start by 1
+ for point3D_id in range(1, num_points3D + 1):
+ original_track_idx = valid_idx[point3D_id - 1]
+
+ if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all():
+ if masks[fidx][original_track_idx]:
+ # It seems we don't need +0.5 for BA
+ point2D_xy = tracks[fidx][original_track_idx]
+ # Please note when adding the Point2D object
+ # It not only requires the 2D xy location, but also the id to 3D point
+ points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
+
+ # add element
+ track = reconstruction.points3D[point3D_id].track
+ track.add_element(fidx + 1, point2D_idx)
+ point2D_idx += 1
+
+ assert point2D_idx == len(points2D_list)
+
+ try:
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
+ image.registered = True
+ except Exception:
+ print(f"frame {fidx + 1} is out of BA")
+ image.registered = False
+
+ # add image
+ reconstruction.add_image(image)
+
+ return reconstruction, valid_mask
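+
+# A minimal usage sketch (kept as a comment; the array names and the reprojection threshold below are
+# illustrative placeholders, not fixed values of this repo):
+#
+# rec, valid = batch_np_matrix_to_pycolmap(
+#     points3d,                  # (P, 3) triangulated points
+#     extrinsics, intrinsics,    # (N, 3, 4) world-to-cam, (N, 3, 3) K
+#     tracks,                    # (N, P, 2) per-frame 2D observations
+#     np.array([width, height]),
+#     masks=visibility,          # (N, P) bool visibility, optional
+#     max_reproj_error=8.0,
+# )
+# if rec is not None:
+#     rec.write(output_dir)      # standard COLMAP binary output (cameras/images/points3D)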
+
+
+def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"):
+ """
+ Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays.
+
+ Args:
+ reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
+ device (str): Ignored in NumPy version (kept for API compatibility).
+ camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
+
+ Returns:
+ tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
+ """
+
+ num_images = len(reconstruction.images)
+ max_points3D_id = max(reconstruction.point3D_ids())
+ points3D = np.zeros((max_points3D_id, 3))
+
+ for point3D_id in reconstruction.points3D:
+ points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
+
+ extrinsics = []
+ intrinsics = []
+
+ extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
+
+ for i in range(num_images):
+ # Extract and append extrinsics
+ pyimg = reconstruction.images[i + 1]
+ pycam = reconstruction.cameras[pyimg.camera_id]
+ matrix = pyimg.cam_from_world.matrix()
+ extrinsics.append(matrix)
+
+ # Extract and append intrinsics
+ calibration_matrix = pycam.calibration_matrix()
+ intrinsics.append(calibration_matrix)
+
+ if camera_type == "SIMPLE_RADIAL":
+ extra_params.append(pycam.params[-1])
+
+ # Convert lists to NumPy arrays instead of torch tensors
+ extrinsics = np.stack(extrinsics)
+ intrinsics = np.stack(intrinsics)
+
+ if camera_type == "SIMPLE_RADIAL":
+ extra_params = np.stack(extra_params)
+ extra_params = extra_params[:, None]
+
+ return points3D, extrinsics, intrinsics, extra_params
+
+
+########################################################
+
+
+def batch_np_matrix_to_pycolmap_wo_track(
+ points3d,
+ points_xyf,
+ points_rgb,
+ extrinsics,
+ intrinsics,
+ image_size,
+ shared_camera=False,
+ camera_type="SIMPLE_PINHOLE",
+):
+ """
+ Convert Batched NumPy Arrays to PyCOLMAP
+
+ Different from batch_np_matrix_to_pycolmap, this function does not use tracks.
+
+ It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods.
+
+ Do NOT use this for BA.
+ """
+ # points3d: Px3
+ # points_xyf: Px3, with x, y coordinates and frame indices
+ # points_rgb: Px3, rgb colors
+ # extrinsics: Nx3x4
+ # intrinsics: Nx3x3
+ # image_size: 2, assume all the frames have been padded to the same size
+ # where N is the number of frames and P is the number of tracks
+
+ N = len(extrinsics)
+ P = len(points3d)
+
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
+ reconstruction = pycolmap.Reconstruction()
+
+ for vidx in range(P):
+ reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx])
+
+ camera = None
+ # frame idx
+ for fidx in range(N):
+ # set camera
+ if camera is None or (not shared_camera):
+ pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type)
+
+ camera = pycolmap.Camera(
+ model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1
+ )
+
+ # add camera
+ reconstruction.add_camera(camera)
+
+ # set image
+ cam_from_world = pycolmap.Rigid3d(
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
+ ) # Rot and Trans
+
+ image = pycolmap.Image(
+ id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world
+ )
+
+ points2D_list = []
+
+ point2D_idx = 0
+
+ points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx
+ points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0]
+
+ for point3D_batch_idx in points_belong_to_fidx:
+ point3D_id = point3D_batch_idx + 1
+ point2D_xyf = points_xyf[point3D_batch_idx]
+ point2D_xy = point2D_xyf[:2]
+ points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
+
+ # add element
+ track = reconstruction.points3D[point3D_id].track
+ track.add_element(fidx + 1, point2D_idx)
+ point2D_idx += 1
+
+ assert point2D_idx == len(points2D_list)
+
+ try:
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
+ image.registered = True
+ except:
+ print(f"frame {fidx + 1} does not have any points")
+ image.registered = False
+
+ # add image
+ reconstruction.add_image(image)
+
+ return reconstruction
+
+
+def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None):
+ """
+ Helper function to get camera parameters based on camera type.
+
+ Args:
+ fidx: Frame index
+ intrinsics: Camera intrinsic parameters
+ camera_type: Type of camera model
+ extra_params: Additional parameters for certain camera types
+
+ Returns:
+ pycolmap_intri: NumPy array of camera parameters
+ """
+ if camera_type == "PINHOLE":
+ pycolmap_intri = np.array(
+ [intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]
+ )
+ elif camera_type == "SIMPLE_PINHOLE":
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
+ pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]])
+ elif camera_type == "SIMPLE_RADIAL":
+ raise NotImplementedError("SIMPLE_RADIAL is not supported yet")
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
+ pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]])
+ else:
+ raise ValueError(f"Camera type {camera_type} is not supported yet")
+
+ return pycolmap_intri
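As a quick numeric check of the parameter packing above (an editorial sketch with made-up values, assuming _build_pycolmap_intri is in scope): PINHOLE keeps fx and fy separately, while SIMPLE_PINHOLE averages them into a single focal length.

import numpy as np

# One 3x3 calibration matrix with slightly different fx and fy (illustrative values).
K = np.array([[500.0, 0.0, 320.0],
              [0.0, 510.0, 240.0],
              [0.0, 0.0, 1.0]])
intrinsics = K[None]  # (1, 3, 3); frame index 0

print(_build_pycolmap_intri(0, intrinsics, "PINHOLE"))         # [500. 510. 320. 240.]
print(_build_pycolmap_intri(0, intrinsics, "SIMPLE_PINHOLE"))  # [505. 320. 240.]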
diff --git a/libs/vggt/dependency/projection.py b/libs/vggt/dependency/projection.py
new file mode 100644
index 0000000000000000000000000000000000000000..a98082dc2f5b3c057b398a03ab13dba470f4a111
--- /dev/null
+++ b/libs/vggt/dependency/projection.py
@@ -0,0 +1,228 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+from .distortion import apply_distortion
+
+
+def img_from_cam_np(
+ intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0
+) -> np.ndarray:
+ """
+ Apply intrinsics (and optional radial distortion) to camera-space points.
+
+ Args
+ ----
+ intrinsics : (B,3,3) camera matrix K.
+ points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ.
+ extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None.
+ default : value used for np.nan replacement.
+
+ Returns
+ -------
+ points2D : (B,N,2) pixel coordinates.
+ """
+ # 1. perspective divide ───────────────────────────────────────
+ z = points_cam[:, 2:3, :] # (B,1,N)
+ points_cam_norm = points_cam / z # (B,3,N)
+ uv = points_cam_norm[:, :2, :] # (B,2,N)
+
+ # 2. optional distortion ──────────────────────────────────────
+ if extra_params is not None:
+ uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
+ uv = np.stack([uu, vv], axis=1) # (B,2,N)
+
+ # 3. homogeneous coords then K multiplication ─────────────────
+ ones = np.ones_like(uv[:, :1, :]) # (B,1,N)
+ points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N)
+
+ # batched mat-mul: K · [u v 1]ᵀ
+ points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N)
+ points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N)
+
+ return points2D.transpose(0, 2, 1) # (B,N,2)
+
+
+def project_3D_points_np(
+ points3D: np.ndarray,
+ extrinsics: np.ndarray,
+ intrinsics: np.ndarray | None = None,
+ extra_params: np.ndarray | None = None,
+ *,
+ default: float = 0.0,
+ only_points_cam: bool = False,
+):
+ """
+ NumPy clone of ``project_3D_points``.
+
+ Parameters
+ ----------
+ points3D : (N,3) world-space points.
+ extrinsics : (B,3,4) [R|t] matrix for each of B cameras.
+ intrinsics : (B,3,3) K matrix (optional if you only need cam-space).
+ extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None.
+ default : value used to replace NaNs.
+ only_points_cam : if True, skip the projection and return points_cam with points2D as None.
+
+ Returns
+ -------
+ (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True,
+ and points_cam is (B,3,N) camera-space coordinates.
+ """
+ # ----- 0. prep sizes -----------------------------------------------------
+ N = points3D.shape[0] # #points
+ B = extrinsics.shape[0] # #cameras
+
+ # ----- 1. world → homogeneous -------------------------------------------
+ w_h = np.ones((N, 1), dtype=points3D.dtype)
+ points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4)
+
+ # broadcast to every camera (no actual copying with np.broadcast_to) ------
+ points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4)
+
+ # ----- 2. apply extrinsics (camera frame) ------------------------------
+ # X_cam = E · X_hom
+ # einsum: E_(b i j) · X_(b n j) → (b n i)
+ points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3)
+ points_cam = points_cam.transpose(0, 2, 1) # (B,3,N)
+
+ if only_points_cam:
+ return None, points_cam
+
+ # ----- 3. intrinsics + distortion ---------------------------------------
+ if intrinsics is None:
+ raise ValueError("`intrinsics` must be provided unless only_points_cam=True")
+
+ points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default)
+
+ return points2D, points_cam
+
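A minimal numeric sketch of the NumPy projection path above (editorial; the import assumes libs/ is on PYTHONPATH so the package resolves as vggt.dependency, and all values are illustrative). With identity extrinsics, a point at depth 2 projects to K applied to (x/z, y/z, 1).

import numpy as np
from vggt.dependency.projection import project_3D_points_np  # adjust to your layout

# One camera at the world origin looking down +z: extrinsics = [I | 0].
extrinsics = np.concatenate([np.eye(3), np.zeros((3, 1))], axis=1)[None]   # (1, 3, 4)
K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]])[None]                                      # (1, 3, 3)

points3D = np.array([[0.5, -0.25, 2.0]])                                   # (1, 3) world point

points2D, points_cam = project_3D_points_np(points3D, extrinsics, K)
print(points_cam[0, :, 0])  # [ 0.5  -0.25  2.0 ] (camera frame == world frame here)
print(points2D[0, 0])       # [445.  177.5] = [500*0.25 + 320, 500*(-0.125) + 240]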
+
+def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False):
+ """
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
+ Args:
+ points3D (torch.Tensor): 3D points of shape Px3.
+ extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
+ intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
+ extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion.
+ default (float): Default value to replace NaNs.
+ only_points_cam (bool): If True, skip the projection and return points2D as None.
+
+ Returns:
+ tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True,
+ and points_cam is of shape Bx3xN.
+ """
+ with torch.cuda.amp.autocast(dtype=torch.double):
+ N = points3D.shape[0] # Number of points
+ B = extrinsics.shape[0] # Batch size, i.e., number of cameras
+ points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4
+ # Reshape for batch processing
+ points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4
+
+ # Step 1: Apply extrinsic parameters
+ # Transform 3D points to camera coordinate system for all cameras
+ points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2))
+
+ if only_points_cam:
+ return None, points_cam
+
+ # Step 2: Apply intrinsic parameters and (optional) distortion
+ points2D = img_from_cam(intrinsics, points_cam, extra_params, default)
+
+ return points2D, points_cam
+
+
+def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):
+ """
+ Applies intrinsic parameters and optional distortion to the given 3D points.
+
+ Args:
+ intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
+ points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
+ extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
+ default (float, optional): Default value to replace NaNs in the output.
+
+ Returns:
+ points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
+ """
+
+ # Normalize by the third coordinate (homogeneous division)
+ points_cam = points_cam / points_cam[:, 2:3, :]
+ # Extract uv
+ uv = points_cam[:, :2, :]
+
+ # Apply distortion if extra_params are provided
+ if extra_params is not None:
+ uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
+ uv = torch.stack([uu, vv], dim=1)
+
+ # Prepare points_cam for batch matrix multiplication
+ points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN
+ # Apply intrinsic parameters using batch matrix multiplication
+ points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN
+
+ # Extract x and y coordinates
+ points2D = points2D_homo[:, :2, :] # Bx2xN
+
+ # Replace NaNs with default value
+ points2D = torch.nan_to_num(points2D, nan=default)
+
+ return points2D.transpose(1, 2) # BxNx2
+
+
+if __name__ == "__main__":
+ # Set up example input
+ B, N = 24, 10240
+
+ for _ in range(100):
+ points3D = np.random.rand(N, 3).astype(np.float64)
+ extrinsics = np.random.rand(B, 3, 4).astype(np.float64)
+ intrinsics = np.random.rand(B, 3, 3).astype(np.float64)
+
+ # Convert to torch tensors
+ points3D_torch = torch.tensor(points3D)
+ extrinsics_torch = torch.tensor(extrinsics)
+ intrinsics_torch = torch.tensor(intrinsics)
+
+ # Run NumPy implementation
+ points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics)
+
+ # Run torch implementation
+ points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch)
+
+ # Convert torch output to numpy
+ points2D_torch_np = points2D_torch.detach().numpy()
+ points_cam_torch_np = points_cam_torch.detach().numpy()
+
+ # Compute difference
+ diff = np.abs(points2D_np - points2D_torch_np)
+ print("Difference between NumPy and PyTorch implementations:")
+ print(diff)
+
+ # Check max error
+ max_diff = np.max(diff)
+ print(f"Maximum difference: {max_diff}")
+
+ if np.allclose(points2D_np, points2D_torch_np, atol=1e-6):
+ print("Implementations match closely.")
+ else:
+ print("Significant differences detected.")
+
+ if points_cam_np is not None:
+ points_cam_diff = np.abs(points_cam_np - points_cam_torch_np)
+ print("Difference between NumPy and PyTorch camera-space coordinates:")
+ print(points_cam_diff)
+
+ # Check max error
+ max_cam_diff = np.max(points_cam_diff)
+ print(f"Maximum camera-space coordinate difference: {max_cam_diff}")
+
+ if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6):
+ print("Camera-space coordinates match closely.")
+ else:
+ print("Significant differences detected in camera-space coordinates.")
diff --git a/libs/vggt/dependency/track_modules/__init__.py b/libs/vggt/dependency/track_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/libs/vggt/dependency/track_modules/base_track_predictor.py b/libs/vggt/dependency/track_modules/base_track_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8218c014e20baa646b612e368d8bdd1841658d65
--- /dev/null
+++ b/libs/vggt/dependency/track_modules/base_track_predictor.py
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from .blocks import EfficientUpdateFormer, CorrBlock
+from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
+
+
+class BaseTrackerPredictor(nn.Module):
+ def __init__(
+ self,
+ stride=4,
+ corr_levels=5,
+ corr_radius=4,
+ latent_dim=128,
+ hidden_size=384,
+ use_spaceatt=True,
+ depth=6,
+ fine=False,
+ ):
+ super(BaseTrackerPredictor, self).__init__()
+ """
+ The base template to create a track predictor
+
+ Modified from https://github.com/facebookresearch/co-tracker/
+ """
+
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.corr_levels = corr_levels
+ self.corr_radius = corr_radius
+ self.hidden_size = hidden_size
+ self.fine = fine
+
+ self.flows_emb_dim = latent_dim // 2
+ self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2
+
+ if self.fine:
+ # TODO this is the old dummy code, will remove this when we train next model
+ self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5
+ else:
+ self.transformer_dim += (4 - self.transformer_dim % 4) % 4
+
+ space_depth = depth if use_spaceatt else 0
+ time_depth = depth
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=self.transformer_dim,
+ hidden_size=self.hidden_size,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=use_spaceatt,
+ )
+
+ self.norm = nn.GroupNorm(1, self.latent_dim)
+
+ # A linear layer to update track feats at each iteration
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
+
+ if not self.fine:
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1):
+ """
+ query_points: B x N x 2, the number of batches, tracks, and xy
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
+ note HH and WW is the size of feature maps instead of original images
+ """
+ B, N, D = query_points.shape
+ B, S, C, HH, WW = fmaps.shape
+
+ assert D == 2
+
+ # Scale the input query_points because we may downsample the images
+ # by down_ratio or self.stride
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
+ # its query_points should be query_points/4
+ if down_ratio > 1:
+ query_points = query_points / float(down_ratio)
+ query_points = query_points / float(self.stride)
+
+ # Init with coords as the query points
+ # It means the search will start from the position of query points at the reference frames
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
+
+ # Sample/extract the features of the query points in the query frame
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
+
+ # init track feats by query feats
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
+ # back up the init coords
+ coords_backup = coords.clone()
+
+ # Construct the correlation block
+
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
+
+ coord_preds = []
+
+ # Iterative Refinement
+ for itr in range(iters):
+ # Detach the gradients from the last iteration
+ # (in my experience, not very important for performance)
+ coords = coords.detach()
+
+ # Compute the correlation (check the implementation of CorrBlock)
+
+ fcorr_fn.corr(track_feats)
+ fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim
+
+ corrdim = fcorrs.shape[3]
+
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim)
+
+ # Movement of current coords relative to query points
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
+
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
+ flows_emb = torch.cat([flows_emb, flows], dim=-1)
+
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ # Concatenate them as the input for the transformers
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
+
+ if transformer_input.shape[2] < self.transformer_dim:
+ # pad the features to match the dimension
+ pad_dim = self.transformer_dim - transformer_input.shape[2]
+ pad = torch.zeros_like(flows_emb[..., 0:pad_dim])
+ transformer_input = torch.cat([transformer_input, pad], dim=2)
+
+ # 2D positional embed
+ # TODO: this can be much simplified
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
+
+ x = transformer_input + sampled_pos_emb
+
+ # B, N, S, C
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
+
+ # Compute the delta coordinates and delta track features
+ delta = self.updateformer(x)
+ # BN, S, C
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+
+ # Update the track features
+ track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
+
+ # B x S x N x 2
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+
+ # Force coord0 as query
+ # because we assume the query points should not be changed
+ coords[:, 0] = coords_backup[:, 0]
+
+ # The predicted tracks are in the original image scale
+ if down_ratio > 1:
+ coord_preds.append(coords * self.stride * down_ratio)
+ else:
+ coord_preds.append(coords * self.stride)
+
+ # B, S, N
+ if not self.fine:
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ vis_e = torch.sigmoid(vis_e)
+ else:
+ vis_e = None
+
+ if return_feat:
+ return coord_preds, vis_e, track_feats, query_track_feat
+ else:
+ return coord_preds, vis_e
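The shapes flowing through the tracker are easiest to see with dummy inputs. The sketch below is editorial (not part of the patch); it assumes libs/ is on PYTHONPATH so the package resolves as vggt.dependency, and that the feature channel count matches latent_dim.

import torch
from vggt.dependency.track_modules.base_track_predictor import BaseTrackerPredictor

B, S, N = 1, 4, 16           # batch, frames, tracks
C, HH, WW = 128, 32, 32      # feature channels must equal latent_dim (128 by default)

tracker = BaseTrackerPredictor(stride=4, latent_dim=C, depth=6)
query_points = torch.rand(B, N, 2) * 127   # pixel coords in the (stride * HH)-sized image
fmaps = torch.randn(B, S, C, HH, WW)       # per-frame feature maps

coord_preds, vis = tracker(query_points, fmaps=fmaps, iters=2)
print(len(coord_preds), coord_preds[-1].shape)  # 2 torch.Size([1, 4, 16, 2]) in image scale
print(vis.shape)                                # torch.Size([1, 4, 16]) visibility in [0, 1]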
diff --git a/libs/vggt/dependency/track_modules/blocks.py b/libs/vggt/dependency/track_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0017d2c25338d0ce5d3f31e3802282259c8fa36
--- /dev/null
+++ b/libs/vggt/dependency/track_modules/blocks.py
@@ -0,0 +1,329 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Modified from https://github.com/facebookresearch/co-tracker/
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+from .utils import bilinear_sampler
+
+from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
+
+
+class BasicEncoder(nn.Module):
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
+ super(BasicEncoder, self).__init__()
+
+ self.stride = stride
+ self.norm_fn = "instance"
+ self.in_planes = output_dim // 2
+
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+
+ self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros")
+ self.relu1 = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
+ self.layer3 = self._make_layer(output_dim, stride=2)
+ self.layer4 = self._make_layer(output_dim, stride=2)
+
+ self.conv2 = nn.Conv2d(
+ output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros"
+ )
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.InstanceNorm2d)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ _, _, H, W = x.shape
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ a = self.layer1(x)
+ b = self.layer2(a)
+ c = self.layer3(b)
+ d = self.layer4(c)
+
+ a = _bilinear_intepolate(a, self.stride, H, W)
+ b = _bilinear_intepolate(b, self.stride, H, W)
+ c = _bilinear_intepolate(c, self.stride, H, W)
+ d = _bilinear_intepolate(d, self.stride, H, W)
+
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
+ x = self.norm2(x)
+ x = self.relu2(x)
+ x = self.conv3(x)
+ return x
+
+
+class ShallowEncoder(nn.Module):
+ def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
+ super(ShallowEncoder, self).__init__()
+ self.stride = stride
+ self.norm_fn = norm_fn
+ self.in_planes = output_dim
+
+ if self.norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
+ self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
+ elif self.norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(self.in_planes)
+ self.norm2 = nn.BatchNorm2d(output_dim * 2)
+ elif self.norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+ elif self.norm_fn == "none":
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros")
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(output_dim, stride=2)
+
+ self.layer2 = self._make_layer(output_dim, stride=2)
+ self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ self.in_planes = dim
+
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ return layer1
+
+ def forward(self, x):
+ _, _, H, W = x.shape
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ tmp = self.layer1(x)
+ x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
+ tmp = self.layer2(tmp)
+ x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
+ tmp = None
+ x = self.conv2(x) + x
+
+ x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True)
+
+ return x
+
+
+def _bilinear_intepolate(x, stride, H, W):
+ return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True)
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+
+ if self.add_space_attn:
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+ else:
+ self.virual_tracks = None
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ # Apply the basic (Xavier) initialization defined above to all Linear layers.
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ tokens = self.input_transform(input_tensor)
+
+ init_tokens = tokens
+
+ B, _, T, _ = tokens.shape
+
+ if self.add_space_attn:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+
+ if self.add_space_attn:
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ tokens = tokens + init_tokens
+
+ flow = self.flow_head(tokens)
+ return flow
+
+
+class CorrBlock:
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.padding_mode = padding_mode
+ self.num_levels = num_levels
+ self.radius = radius
+ self.fmaps_pyramid = []
+ self.multiple_track_feats = multiple_track_feats
+
+ self.fmaps_pyramid.append(fmaps)
+ for i in range(self.num_levels - 1):
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ _, _, H, W = fmaps_.shape
+ fmaps = fmaps_.reshape(B, S, C, H, W)
+ self.fmaps_pyramid.append(fmaps)
+
+ def sample(self, coords):
+ r = self.radius
+ B, S, N, D = coords.shape
+ assert D == 2
+
+ H, W = self.H, self.W
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
+ *_, H, W = corrs.shape
+
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
+
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode)
+ corrs = corrs.view(B, S, N, -1)
+
+ out_pyramid.append(corrs)
+
+ out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2
+ return out
+
+ def corr(self, targets):
+ B, S, N, C = targets.shape
+ if self.multiple_track_feats:
+ targets_split = targets.split(C // self.num_levels, dim=-1)
+ B, S, N, C = targets_split[0].shape
+
+ assert C == self.C
+ assert S == self.S
+
+ fmap1 = targets
+
+ self.corrs_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ *_, H, W = fmaps.shape
+ fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
+ if self.multiple_track_feats:
+ fmap1 = targets_split[i]
+ corrs = torch.matmul(fmap1, fmap2s)
+ corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ self.corrs_pyramid.append(corrs)
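A shape-level sketch of how CorrBlock is used by the tracker above (editorial; same import assumption as before). corr() builds per-level correlation volumes, and sample() reads out a (2r+1) x (2r+1) window around each track at every pyramid level.

import torch
from vggt.dependency.track_modules.blocks import CorrBlock

B, S, N, C, H, W = 1, 4, 8, 128, 32, 32
fmaps = torch.randn(B, S, C, H, W)
corr_fn = CorrBlock(fmaps, num_levels=5, radius=4)

track_feats = torch.randn(B, S, N, C)     # one feature vector per track per frame
coords = torch.rand(B, S, N, 2) * (W - 1)

corr_fn.corr(track_feats)                 # B x S x N x H x W correlation volume per level
corr_patches = corr_fn.sample(coords)     # concat of 5 levels x 9 x 9 samples
print(corr_patches.shape)                 # torch.Size([1, 4, 8, 405])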
diff --git a/libs/vggt/dependency/track_modules/modules.py b/libs/vggt/dependency/track_modules/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..e89b26edc7717f04a897977041f26e5c4f1c52b2
--- /dev/null
+++ b/libs/vggt/dependency/track_modules/modules.py
@@ -0,0 +1,202 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class ResidualBlock(nn.Module):
+ """
+ ResidualBlock: construct a block of two conv layers with residual connections
+ """
+
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
+ )
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+ else:
+ raise NotImplementedError
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
+ mlp_ratio=4.0,
+ **block_kwargs,
+ ):
+ """
+ Self attention block
+ """
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, mask=None):
+ # Prepare the mask for PyTorch's attention (it expects a different format)
+ # attn_mask = mask if mask is not None else None
+ # Normalize before attention
+ x = self.norm1(x)
+
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
+
+ attn_output, _ = self.attn(x, x, x)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ """
+ Cross attention block
+ """
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, context, mask=None):
+ # Normalize inputs
+ x = self.norm1(x)
+ context = self.norm_context(context)
+
+ # Apply cross attention
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
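A small editorial sketch of the two attention blocks (same import assumption as before): both are pre-norm residual blocks, with CrossAttnBlock attending from point tokens to a separate context sequence.

import torch
from vggt.dependency.track_modules.modules import AttnBlock, CrossAttnBlock

hidden = 384
tokens = torch.randn(2, 10, hidden)     # (batch, sequence, channels)
context = torch.randn(2, 64, hidden)    # e.g. virtual-track tokens

self_block = AttnBlock(hidden, num_heads=8)
cross_block = CrossAttnBlock(hidden, context_dim=hidden, num_heads=8)

out = cross_block(self_block(tokens), context)
print(out.shape)                        # torch.Size([2, 10, 384]); shape is preserved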
diff --git a/libs/vggt/dependency/track_modules/track_refine.py b/libs/vggt/dependency/track_modules/track_refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..54a7ace1d49686304e5fbf28c33168667c28e181
--- /dev/null
+++ b/libs/vggt/dependency/track_modules/track_refine.py
@@ -0,0 +1,419 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from torch import nn, einsum
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange, Reduce
+
+from PIL import Image
+import os
+from typing import Union, Tuple
+
+
+def refine_track(
+ images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960
+):
+ """
+ Refines the tracking of images using a fine track predictor and a fine feature network.
+ Check https://arxiv.org/abs/2312.04563 for more details.
+
+ Args:
+ images (torch.Tensor): The images to be tracked.
+ fine_fnet (nn.Module): The fine feature network.
+ fine_tracker (nn.Module): The fine track predictor.
+ coarse_pred (torch.Tensor): The coarse predictions of tracks.
+ compute_score (bool, optional): Whether to compute the score. Defaults to False.
+ pradius (int, optional): The radius of a patch. Defaults to 15.
+ sradius (int, optional): The search radius. Defaults to 2.
+
+ Returns:
+ torch.Tensor: The refined tracks.
+ torch.Tensor, optional: The score.
+ """
+
+ # coarse_pred shape: BxSxNx2,
+ # where B is the batch, S is the video/images length, and N is the number of tracks
+ # now we are going to extract patches with the center at coarse_pred
+ # Please note that the last dimension indicates x and y, and hence has size 2
+ B, S, N, _ = coarse_pred.shape
+ _, _, _, H, W = images.shape
+
+ # Given the radius of a patch, compute the patch size
+ psize = pradius * 2 + 1
+
+ # Note that we assume the first frame is the query frame
+ # so the 2D locations of the first frame are the query points
+ query_points = coarse_pred[:, 0]
+
+ # Given 2D positions, we can use grid_sample to extract patches
+ # but it takes too much memory.
+ # Instead, we use the floored track xy to sample patches.
+
+ # For example, if the query point xy is (128.16, 252.78),
+ # and the patch size is (31, 31),
+ # our goal is to extract the content of a rectangle
+ # with left top: (113.16, 237.78)
+ # and right bottom: (143.16, 267.78).
+ # However, we record the floored left top: (113, 237)
+ # and the offset (0.16, 0.78)
+ # Then what we need is just unfolding the images like in CNN,
+ # picking the content at [(113, 237), (143, 267)].
+ # Such operations are highly optimized in pytorch
+ # (well if you really want to use interpolation, check the function extract_glimpse() below)
+
+ with torch.no_grad():
+ content_to_extract = images.reshape(B * S, 3, H, W)
+ C_in = content_to_extract.shape[1]
+
+ # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
+ # for the detailed explanation of unfold()
+ # Here it runs sliding windows (psize x psize) to build patches
+ # The shape changes from
+ # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
+ # where Psize is the size of patch
+ content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
+
+ # Floor the coarse predictions to get integers and save the fractional/decimal part
+ track_int = coarse_pred.floor().int()
+ track_frac = coarse_pred - track_int
+
+ # Note the points represent the center of patches
+ # now we get the location of the top left corner of patches
+ # because the output of pytorch unfold is indexed by the top left corner
+ topleft = track_int - pradius
+ topleft_BSN = topleft.clone()
+
+ # clamp the values so that we will not go out of indexes
+ # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
+ # You need to separately clamp x and y if H!=W
+ topleft = topleft.clamp(0, H - psize)
+
+ # Reshape from BxSxNx2 -> (B*S)xNx2
+ topleft = topleft.reshape(B * S, N, 2)
+
+ # Prepare batches for indexing, shape: (B*S)xN
+ batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
+
+ # extracted_patches: (B*S) x N x C_in x Psize x Psize
+ extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]]
+
+ if chunk < 0:
+ # Extract image patches based on top left corners
+ # Feed patches to fine fnet for features
+ patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
+ else:
+ patches = extracted_patches.reshape(B * S * N, C_in, psize, psize)
+
+ patch_feat_list = []
+ for p in torch.split(patches, chunk):
+ patch_feat_list += [fine_fnet(p)]
+ patch_feat = torch.cat(patch_feat_list, 0)
+
+ C_out = patch_feat.shape[1]
+
+ # Refine the coarse tracks by fine_tracker
+ # reshape back to B x S x N x C_out x Psize x Psize
+ patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
+ patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
+
+ # Prepare for the query points for fine tracker
+ # They are relative to the patch left top corner,
+ # instead of the image top left corner now
+ # patch_query_points: N x 1 x 2
+ # only 1 here because for each patch we only have 1 query point
+ patch_query_points = track_frac[:, 0] + pradius
+ patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
+
+ # Feed the PATCH query points and tracks into fine tracker
+ fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
+ query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True
+ )
+
+ # relative the patch top left
+ fine_pred_track = fine_pred_track_lists[-1].clone()
+
+ # From (relative to the patch top left) to (relative to the image top left)
+ for idx in range(len(fine_pred_track_lists)):
+ fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N)
+ fine_level = fine_level.squeeze(-2)
+ fine_level = fine_level + topleft_BSN
+ fine_pred_track_lists[idx] = fine_level
+
+ # relative to the image top left
+ refined_tracks = fine_pred_track_lists[-1].clone()
+ refined_tracks[:, 0] = query_points
+
+ score = None
+
+ if compute_score:
+ score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out)
+
+ return refined_tracks, score
+
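The integer/fractional bookkeeping used above can be checked with plain numbers; this editorial sketch mirrors the (128.16, 252.78) example from the comments.

import torch

coarse_xy = torch.tensor([[128.16, 252.78]])   # one coarse track location
pradius = 15                                   # patch radius -> 31 x 31 patches

track_int = coarse_xy.floor().int()            # tensor([[128, 252]])
track_frac = coarse_xy - track_int             # ~tensor([[0.1600, 0.7800]])
topleft = track_int - pradius                  # tensor([[113, 237]]), index into unfold output
patch_query = track_frac + pradius             # ~tensor([[15.1600, 15.7800]]), inside the patch
print(topleft, patch_query)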
+
+def refine_track_v0(
+ images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6
+):
+ """
+ COPIED FROM VGGSfM
+
+ Refines the tracking of images using a fine track predictor and a fine feature network.
+ Check https://arxiv.org/abs/2312.04563 for more details.
+
+ Args:
+ images (torch.Tensor): The images to be tracked.
+ fine_fnet (nn.Module): The fine feature network.
+ fine_tracker (nn.Module): The fine track predictor.
+ coarse_pred (torch.Tensor): The coarse predictions of tracks.
+ compute_score (bool, optional): Whether to compute the score. Defaults to False.
+ pradius (int, optional): The radius of a patch. Defaults to 15.
+ sradius (int, optional): The search radius. Defaults to 2.
+
+ Returns:
+ torch.Tensor: The refined tracks.
+ torch.Tensor, optional: The score.
+ """
+
+ # coarse_pred shape: BxSxNx2,
+ # where B is the batch, S is the video/images length, and N is the number of tracks
+ # now we are going to extract patches with the center at coarse_pred
+ # Please note that the last dimension indicates x and y, and hence has size 2
+ B, S, N, _ = coarse_pred.shape
+ _, _, _, H, W = images.shape
+
+ # Given the radius of a patch, compute the patch size
+ psize = pradius * 2 + 1
+
+ # Note that we assume the first frame is the query frame
+ # so the 2D locations of the first frame are the query points
+ query_points = coarse_pred[:, 0]
+
+ # Given 2D positions, we can use grid_sample to extract patches
+ # but it takes too much memory.
+ # Instead, we use the floored track xy to sample patches.
+
+ # For example, if the query point xy is (128.16, 252.78),
+ # and the patch size is (31, 31),
+ # our goal is to extract the content of a rectangle
+ # with left top: (113.16, 237.78)
+ # and right bottom: (143.16, 267.78).
+ # However, we record the floored left top: (113, 237)
+ # and the offset (0.16, 0.78)
+ # Then what we need is just unfolding the images like in CNN,
+ # picking the content at [(113, 237), (143, 267)].
+ # Such operations are highly optimized in pytorch
+ # (well if you really want to use interpolation, check the function extract_glimpse() below)
+
+ with torch.no_grad():
+ content_to_extract = images.reshape(B * S, 3, H, W)
+ C_in = content_to_extract.shape[1]
+
+ # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
+ # for the detailed explanation of unfold()
+ # Here it runs sliding windows (psize x psize) to build patches
+ # The shape changes from
+ # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
+ # where Psize is the size of patch
+ content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
+
+ # Floor the coarse predictions to get integers and save the fractional/decimal part
+ track_int = coarse_pred.floor().int()
+ track_frac = coarse_pred - track_int
+
+ # Note the points represent the center of patches
+ # now we get the location of the top left corner of patches
+ # because the output of pytorch unfold is indexed by the top left corner
+ topleft = track_int - pradius
+ topleft_BSN = topleft.clone()
+
+ # clamp the values so that we will not go out of indexes
+ # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
+ # You need to separately clamp x and y if H!=W
+ topleft = topleft.clamp(0, H - psize)
+
+ # Reshape from BxSxNx2 -> (B*S)xNx2
+ topleft = topleft.reshape(B * S, N, 2)
+
+ # Prepare batches for indexing, shape: (B*S)xN
+ batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
+
+ # Extract image patches based on top left corners
+ # extracted_patches: (B*S) x N x C_in x Psize x Psize
+ extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]]
+
+ # Feed patches to fine fnet for features
+ patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
+
+ C_out = patch_feat.shape[1]
+
+ # Refine the coarse tracks by fine_tracker
+
+ # reshape back to B x S x N x C_out x Psize x Psize
+ patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
+ patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
+
+ # Prepare for the query points for fine tracker
+ # They are relative to the patch left top corner,
+ # instead of the image top left corner now
+ # patch_query_points: N x 1 x 2
+ # only 1 here because for each patch we only have 1 query point
+ patch_query_points = track_frac[:, 0] + pradius
+ patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
+
+ # Feed the PATCH query points and tracks into fine tracker
+ fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
+ query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True
+ )
+
+ # relative the patch top left
+ fine_pred_track = fine_pred_track_lists[-1].clone()
+
+ # From (relative to the patch top left) to (relative to the image top left)
+ for idx in range(len(fine_pred_track_lists)):
+ fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N)
+ fine_level = fine_level.squeeze(-2)
+ fine_level = fine_level + topleft_BSN
+ fine_pred_track_lists[idx] = fine_level
+
+ # relative to the image top left
+ refined_tracks = fine_pred_track_lists[-1].clone()
+ refined_tracks[:, 0] = query_points
+
+ score = None
+
+ if compute_score:
+ score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out)
+
+ return refined_tracks, score
+
+
+################################## NOTE: NOT USED ##################################
+
+
+def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out):
+ """
+ Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps,
+ given the query point features and reference frame feature maps
+ """
+
+ from kornia.utils.grid import create_meshgrid
+ from kornia.geometry.subpix import dsnt
+
+ # query_point_feat initial shape: B x N x C_out,
+ # query_point_feat indicates the feat at the corresponding query points
+ # Therefore we don't have S dimension here
+ query_point_feat = query_point_feat.reshape(B, N, C_out)
+ # reshape and expand to B x (S-1) x N x C_out
+ query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1)
+ # and reshape to (B*(S-1)*N) x C_out
+ query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out)
+
+ # Radius and size for computing the score
+ ssize = sradius * 2 + 1
+
+ # Reshape, you know it, so many reshaping operations
+ patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N)
+
+ # Again, we unfold the patches to smaller patches
+ # so that we can then focus on smaller patches
+ # patch_feat_unfold shape:
+ # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize
+ # well a bit scary, but actually not
+ patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1)
+
+ # Do the same stuffs above, i.e., the same as extracting patches
+ fine_prediction_floor = fine_pred_track.floor().int()
+ fine_level_floor_topleft = fine_prediction_floor - sradius
+
+ # Clamp to ensure the smaller patch is valid
+ fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize)
+ fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2)
+
+ # Prepare the batch indices and xy locations
+
+ batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN
+ batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N
+ y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices
+ x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices
+
+ reference_frame_feat = patch_feat_unfold.reshape(
+ B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize
+ )
+
+ # Note again, according to pytorch convention
+ # x_indices corresponds to [..., 1] and y_indices corresponds to [..., 0]
+ reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices]
+ reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize)
+ # pick the frames other than the first one, so we have S-1 frames here
+ reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize)
+
+ # Compute similarity
+ sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat)
+ softmax_temp = 1.0 / C_out**0.5
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)
+ # 2D heatmaps
+ heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize
+
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
+ grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape(
+ 1, -1, 2
+ )
+
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability
+
+ score = std.reshape(B, S - 1, N)
+ # set score as 1 for the query frame
+ score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1)
+
+ return score
+
+
+def extract_glimpse(
+ tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None
+):
+ B, C, W, H = tensor.shape
+
+ h, w = size
+ xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0
+ ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0
+
+ vy, vx = torch.meshgrid(ys, xs)
+ grid = torch.stack([vx, vy], dim=-1) # h, w, 2
+ grid = grid[None]
+
+ B, N, _ = offsets.shape
+
+ offsets = offsets.reshape((B * N), 1, 1, 2)
+ offsets_grid = offsets + grid
+
+ # normalised grid to [-1, 1]
+ offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2])
+
+ # BxCxHxW -> Bx1xCxHxW
+ tensor = tensor[:, None]
+
+ # Bx1xCxHxW -> BxNxCxHxW
+ tensor = tensor.expand(-1, N, -1, -1, -1)
+
+ # BxNxCxHxW -> (B*N)xCxHxW
+ tensor = tensor.reshape((B * N), C, W, H)
+
+ sampled = torch.nn.functional.grid_sample(
+ tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode
+ )
+
+ # NOTE: I am not sure it should be h, w or w, h here
+ # but okay for squares
+ sampled = sampled.reshape(B, N, C, h, w)
+
+ return sampled
diff --git a/libs/vggt/dependency/track_modules/utils.py b/libs/vggt/dependency/track_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8954e87beb85e71c5fa4b5d7eb4f2b476680e6f
--- /dev/null
+++ b/libs/vggt/dependency/track_modules/utils.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/PoseDiffusion
+# and https://github.com/facebookresearch/co-tracker/tree/main
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Union
+from einops import rearrange, repeat
+
+
+def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if return_grid:
+ return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2 + 2)
+ return pe
+
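An editorial shape check for the positional embeddings above (same import assumption as before): the grid embedding returns one code per cell, and get_2d_embedding returns 2*C sin/cos features per point, plus the raw coordinates when cat_coords=True.

import torch
from vggt.dependency.track_modules.utils import get_2d_sincos_pos_embed, get_2d_embedding

pos = get_2d_sincos_pos_embed(64, grid_size=(8, 8))
print(pos.shape)                 # torch.Size([1, 64, 8, 8])

xy = torch.rand(2, 5, 2) * 100   # (B, N, 2) coordinates
pe = get_2d_embedding(xy, C=64, cat_coords=True)
print(pe.shape)                  # torch.Size([2, 5, 130]) = 2*64 sin/cos terms + raw xy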
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+ left-most image pixel :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+ the left-most pixel :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device)
+ else:
+ coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
+
+ coords -= 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
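+
+# Usage sketch (illustrative only; a single (x, y) sample on a random feature map):
+#   feat = torch.rand(1, 8, 32, 32)                        # (B, C, H, W)
+#   pts = torch.tensor([[[[3.5, 10.0]]]])                  # (B, H_o, W_o, 2) = (1, 1, 1, 2)
+#   out = bilinear_sampler(feat, pts)                      # -> (1, 8, 1, 1)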
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
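+
+# Usage sketch (illustrative only; random feature map and query points):
+#   fmap = torch.rand(1, 128, 64, 64)                      # (B, C, H, W)
+#   pts = torch.rand(1, 256, 2) * 63                       # (B, R, 2) pixel coordinates
+#   feats = sample_features4d(fmap, pts)                   # -> (1, 256, 128)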
diff --git a/libs/vggt/dependency/track_predict.py b/libs/vggt/dependency/track_predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..c15c23fea612acb9383d7f03d7779b6d0f2dbf82
--- /dev/null
+++ b/libs/vggt/dependency/track_predict.py
@@ -0,0 +1,326 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+from .vggsfm_utils import *
+
+
+def predict_tracks(
+ images,
+ conf=None,
+ points_3d=None,
+ masks=None,
+ max_query_pts=2048,
+ query_frame_num=5,
+ keypoint_extractor="aliked+sp",
+ max_points_num=163840,
+ fine_tracking=True,
+ complete_non_vis=True,
+):
+ """
+ Predict tracks for the given images and masks.
+
+ TODO: support non-square images
+ TODO: support masks
+
+
+ This function predicts the tracks for the given images and masks using the specified query method
+ and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames.
+
+ Args:
+ images: Tensor of shape [S, 3, H, W] containing the input images.
+ conf: Tensor of shape [S, H, W] containing per-pixel confidence scores. Default is None.
+ points_3d: Tensor containing 3D points. Default is None.
+ masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None.
+ max_query_pts: Maximum number of query points. Default is 2048.
+ query_frame_num: Number of query frames to use. Default is 5.
+ keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp".
+ max_points_num: Maximum number of points to process at once. Default is 163840.
+ fine_tracking: Whether to use fine tracking. Default is True.
+ complete_non_vis: Whether to augment non-visible frames. Default is True.
+
+ Returns:
+ pred_tracks: Numpy array containing the predicted tracks.
+ pred_vis_scores: Numpy array containing the visibility scores for the tracks.
+ pred_confs: Numpy array containing the confidence scores for the tracks.
+ pred_points_3d: Numpy array containing the 3D points for the tracks.
+ pred_colors: Numpy array containing the point colors for the tracks, in the range [0, 255].
+ """
+
+ device = images.device
+ dtype = images.dtype
+ tracker = build_vggsfm_tracker().to(device, dtype)
+
+ # Find query frames
+ query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device)
+
+ # Add the first image to the front if not already present
+ if 0 in query_frame_indexes:
+ query_frame_indexes.remove(0)
+ query_frame_indexes = [0, *query_frame_indexes]
+
+ # TODO: add the functionality to handle the masks
+ keypoint_extractors = initialize_feature_extractors(
+ max_query_pts, extractor_method=keypoint_extractor, device=device
+ )
+
+ pred_tracks = []
+ pred_vis_scores = []
+ pred_confs = []
+ pred_points_3d = []
+ pred_colors = []
+
+ fmaps_for_tracker = tracker.process_images_to_fmaps(images)
+
+ if fine_tracking:
+ print("For faster inference, consider disabling fine_tracking")
+
+ for query_index in query_frame_indexes:
+ print(f"Predicting tracks for query frame {query_index}")
+ pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query(
+ query_index,
+ images,
+ conf,
+ points_3d,
+ fmaps_for_tracker,
+ keypoint_extractors,
+ tracker,
+ max_points_num,
+ fine_tracking,
+ device,
+ )
+
+ pred_tracks.append(pred_track)
+ pred_vis_scores.append(pred_vis)
+ pred_confs.append(pred_conf)
+ pred_points_3d.append(pred_point_3d)
+ pred_colors.append(pred_color)
+
+ if complete_non_vis:
+ pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames(
+ pred_tracks,
+ pred_vis_scores,
+ pred_confs,
+ pred_points_3d,
+ pred_colors,
+ images,
+ conf,
+ points_3d,
+ fmaps_for_tracker,
+ keypoint_extractors,
+ tracker,
+ max_points_num,
+ fine_tracking,
+ min_vis=500,
+ non_vis_thresh=0.1,
+ device=device,
+ )
+
+ pred_tracks = np.concatenate(pred_tracks, axis=1)
+ pred_vis_scores = np.concatenate(pred_vis_scores, axis=1)
+ # conf/points_3d are optional; when they are not given, every per-query entry is None
+ pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs and pred_confs[0] is not None else None
+ pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d and pred_points_3d[0] is not None else None
+ pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None
+
+ # from vggt.utils.visual_track import visualize_tracks_on_images
+ # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals")
+
+ return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
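+
+# Usage sketch (illustrative only; assumes a CUDA device and network access for the pretrained
+# tracker and DINO weights; real callers pass video frames rather than random tensors):
+#   images = torch.rand(8, 3, 518, 518).cuda()             # S frames in [0, 1]
+#   tracks, vis, confs, pts3d, colors = predict_tracks(images, max_query_pts=512)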
+
+
+def _forward_on_query(
+ query_index,
+ images,
+ conf,
+ points_3d,
+ fmaps_for_tracker,
+ keypoint_extractors,
+ tracker,
+ max_points_num,
+ fine_tracking,
+ device,
+):
+ """
+ Process a single query frame for track prediction.
+
+ Args:
+ query_index: Index of the query frame
+ images: Tensor of shape [S, 3, H, W] containing the input images
+ conf: Confidence tensor
+ points_3d: 3D points tensor
+ fmaps_for_tracker: Feature maps for the tracker
+ keypoint_extractors: Initialized feature extractors
+ tracker: VGG-SFM tracker
+ max_points_num: Maximum number of points to process at once
+ fine_tracking: Whether to use fine tracking
+ device: Device to use for computation
+
+ Returns:
+ pred_track: Predicted tracks
+ pred_vis: Visibility scores for the tracks
+ pred_conf: Confidence scores for the tracks
+ pred_point_3d: 3D points for the tracks
+ pred_color: Point colors for the tracks (0, 255)
+ """
+ frame_num, _, height, width = images.shape
+
+ query_image = images[query_index]
+ query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False)
+ query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)]
+
+ # Extract the color at the keypoint locations
+ query_points_long = query_points.squeeze(0).round().long()
+ pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]]
+ pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8)
+
+ # Query the confidence and points_3d at the keypoint locations
+ if (conf is not None) and (points_3d is not None):
+ assert height == width
+ assert conf.shape[-2] == conf.shape[-1]
+ assert conf.shape[:3] == points_3d.shape[:3]
+ scale = conf.shape[-1] / width
+
+ query_points_scaled = (query_points.squeeze(0) * scale).round().long()
+ query_points_scaled = query_points_scaled.cpu().numpy()
+
+ pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]]
+ pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]]
+
+ # Heuristic to remove low-confidence points; the 1.2 threshold could be exposed as a parameter.
+ valid_mask = pred_conf > 1.2
+ if valid_mask.sum() > 512:
+ query_points = query_points[:, valid_mask] # Make sure shape is compatible
+ pred_conf = pred_conf[valid_mask]
+ pred_point_3d = pred_point_3d[valid_mask]
+ pred_color = pred_color[valid_mask]
+ else:
+ pred_conf = None
+ pred_point_3d = None
+
+ reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
+
+ images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0)
+ images_feed = images_feed[None] # add batch dimension
+ fmaps_feed = fmaps_feed[None] # add batch dimension
+
+ all_points_num = images_feed.shape[1] * query_points.shape[1]
+
+ # Chunk the query points so each tracker pass stays within GPU memory
+ if all_points_num > max_points_num:
+ num_splits = (all_points_num + max_points_num - 1) // max_points_num
+ query_points = torch.chunk(query_points, num_splits, dim=1)
+ else:
+ query_points = [query_points]
+
+ pred_track, pred_vis, _ = predict_tracks_in_chunks(
+ tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking
+ )
+
+ pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1)
+
+ pred_track = pred_track.squeeze(0).float().cpu().numpy()
+ pred_vis = pred_vis.squeeze(0).float().cpu().numpy()
+
+ return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color
+
+
+def _augment_non_visible_frames(
+ pred_tracks: list, # ← running list of np.ndarrays
+ pred_vis_scores: list, # ← running list of np.ndarrays
+ pred_confs: list, # ← running list of np.ndarrays for confidence scores
+ pred_points_3d: list, # ← running list of np.ndarrays for 3D points
+ pred_colors: list, # ← running list of np.ndarrays for colors
+ images: torch.Tensor,
+ conf,
+ points_3d,
+ fmaps_for_tracker,
+ keypoint_extractors,
+ tracker,
+ max_points_num: int,
+ fine_tracking: bool,
+ *,
+ min_vis: int = 500,
+ non_vis_thresh: float = 0.1,
+ device: torch.device = None,
+):
+ """
+ Augment tracking for frames with insufficient visibility.
+
+ Args:
+ pred_tracks: List of numpy arrays containing predicted tracks.
+ pred_vis_scores: List of numpy arrays containing visibility scores.
+ pred_confs: List of numpy arrays containing confidence scores.
+ pred_points_3d: List of numpy arrays containing 3D points.
+ pred_colors: List of numpy arrays containing point colors.
+ images: Tensor of shape [S, 3, H, W] containing the input images.
+ conf: Tensor of shape [S, H, W] containing per-pixel confidence scores
+ points_3d: Tensor containing 3D points
+ fmaps_for_tracker: Feature maps for the tracker
+ keypoint_extractors: Initialized feature extractors
+ tracker: VGG-SFM tracker
+ max_points_num: Maximum number of points to process at once
+ fine_tracking: Whether to use fine tracking
+ min_vis: Minimum visibility threshold
+ non_vis_thresh: Non-visibility threshold
+ device: Device to use for computation
+
+ Returns:
+ Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists.
+ """
+ last_query = -1
+ final_trial = False
+ cur_extractors = keypoint_extractors # may be replaced on the final trial
+
+ while True:
+ # Visibility per frame
+ vis_array = np.concatenate(pred_vis_scores, axis=1)
+
+ # Count frames with sufficient visibility using numpy
+ sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)
+ non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
+
+ if len(non_vis_frames) == 0:
+ break
+
+ print("Processing non-visible frames:", non_vis_frames)
+
+ # Decide the frames & extractor for this round
+ if non_vis_frames[0] == last_query:
+ # Same frame failed twice - final "all-in" attempt
+ final_trial = True
+ cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device)
+ query_frame_list = non_vis_frames # process all remaining low-visibility frames at once
+ else:
+ query_frame_list = [non_vis_frames[0]] # Process one at a time
+
+ last_query = non_vis_frames[0]
+
+ # Run the tracker for every selected frame
+ for query_index in query_frame_list:
+ new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query(
+ query_index,
+ images,
+ conf,
+ points_3d,
+ fmaps_for_tracker,
+ cur_extractors,
+ tracker,
+ max_points_num,
+ fine_tracking,
+ device,
+ )
+ pred_tracks.append(new_track)
+ pred_vis_scores.append(new_vis)
+ pred_confs.append(new_conf)
+ pred_points_3d.append(new_point_3d)
+ pred_colors.append(new_color)
+
+ if final_trial:
+ break # Stop after final attempt
+
+ return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
diff --git a/libs/vggt/dependency/vggsfm_tracker.py b/libs/vggt/dependency/vggsfm_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..d79aeef000dcfec506dc4afb4e500d22a758122b
--- /dev/null
+++ b/libs/vggt/dependency/vggsfm_tracker.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .track_modules.track_refine import refine_track
+from .track_modules.blocks import BasicEncoder, ShallowEncoder
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackerPredictor(nn.Module):
+ def __init__(self, **extra_args):
+ super(TrackerPredictor, self).__init__()
+ """
+ Initializes the tracker predictor.
+
+ Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor,
+ check track_modules/base_track_predictor.py
+
+ Both coarse_fnet and fine_fnet are constructed as a 2D CNN network
+ check track_modules/blocks.py for BasicEncoder and ShallowEncoder
+ """
+ # Define coarse predictor configuration
+ coarse_stride = 4
+ self.coarse_down_ratio = 2
+
+ # Create networks directly instead of using instantiate
+ self.coarse_fnet = BasicEncoder(stride=coarse_stride)
+ self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)
+
+ # Create fine predictor with stride = 1
+ self.fine_fnet = ShallowEncoder(stride=1)
+ self.fine_predictor = BaseTrackerPredictor(
+ stride=1,
+ depth=4,
+ corr_levels=3,
+ corr_radius=3,
+ latent_dim=32,
+ hidden_size=256,
+ fine=True,
+ use_spaceatt=False,
+ )
+
+ def forward(
+ self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960
+ ):
+ """
+ Args:
+ images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W.
+ query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2.
+ fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None.
+ coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6.
+ inference (bool, optional): Whether to perform inference. Defaults to True.
+ fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True.
+
+ Returns:
+ tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score.
+ """
+
+ if fmaps is None:
+ batch_num, frame_num, image_dim, height, width = images.shape
+ reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width)
+ fmaps = self.process_images_to_fmaps(reshaped_image)
+ fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1])
+
+ if inference:
+ torch.cuda.empty_cache()
+
+ # Coarse prediction
+ coarse_pred_track_lists, pred_vis = self.coarse_predictor(
+ query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio
+ )
+ coarse_pred_track = coarse_pred_track_lists[-1]
+
+ if inference:
+ torch.cuda.empty_cache()
+
+ if fine_tracking:
+ # Refine the coarse prediction
+ fine_pred_track, pred_score = refine_track(
+ images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk
+ )
+
+ if inference:
+ torch.cuda.empty_cache()
+ else:
+ fine_pred_track = coarse_pred_track
+ pred_score = torch.ones_like(pred_vis)
+
+ return fine_pred_track, coarse_pred_track, pred_vis, pred_score
+
+ def process_images_to_fmaps(self, images):
+ """
+ This function processes images for inference.
+
+ Args:
+ images (torch.Tensor): The images to be processed with shape S x 3 x H x W.
+
+ Returns:
+ torch.Tensor: The processed feature maps.
+ """
+ if self.coarse_down_ratio > 1:
+ # Downscale the input images before feature extraction to save memory
+ fmaps = self.coarse_fnet(
+ F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True)
+ )
+ else:
+ fmaps = self.coarse_fnet(images)
+
+ return fmaps
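+
+# Usage sketch (illustrative only; untrained weights and random inputs, shapes assumed):
+#   model = TrackerPredictor().eval()
+#   imgs = torch.rand(1, 4, 3, 256, 256)                   # (B, S, 3, H, W) in [0, 1]
+#   queries = torch.rand(1, 128, 2) * 255                  # (B, N, 2) pixel coords in frame 0
+#   with torch.no_grad():
+#       fine, coarse, vis, score = model(imgs, queries, fine_tracking=False)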
diff --git a/libs/vggt/dependency/vggsfm_utils.py b/libs/vggt/dependency/vggsfm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f7d9ba6a28da07b7f030a17730f4826feaa828e
--- /dev/null
+++ b/libs/vggt/dependency/vggsfm_utils.py
@@ -0,0 +1,305 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pycolmap
+import torch
+import torch.nn.functional as F
+from lightglue import ALIKED, SIFT, SuperPoint
+
+from .vggsfm_tracker import TrackerPredictor
+
+# Suppress verbose logging from dependencies
+logging.getLogger("dinov2").setLevel(logging.WARNING)
+warnings.filterwarnings("ignore", message="xFormers is available")
+warnings.filterwarnings("ignore", message="dinov2")
+
+# Constants
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+def build_vggsfm_tracker(model_path=None):
+ """
+ Build and initialize the VGGSfM tracker.
+
+ Args:
+ model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace.
+
+ Returns:
+ Initialized tracker model in eval mode.
+ """
+ tracker = TrackerPredictor()
+
+ if model_path is None:
+ default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
+ tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url))
+ else:
+ tracker.load_state_dict(torch.load(model_path))
+
+ tracker.eval()
+ return tracker
+
+
+def generate_rank_by_dino(
+ images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False
+):
+ """
+ Generate a ranking of frames using DINO ViT features.
+
+ Args:
+ images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
+ query_frame_num: Number of frames to select
+ image_size: Size to resize images to before processing
+ model_name: Name of the DINO model to use
+ device: Device to run the model on
+ spatial_similarity: Whether to use spatial token similarity or CLS token similarity
+
+ Returns:
+ List of frame indices ranked by their representativeness
+ """
+ # Resize images to the target size
+ images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False)
+
+ # Load DINO model
+ dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
+ dino_v2_model.eval()
+ dino_v2_model = dino_v2_model.to(device)
+
+ # Normalize images using ResNet normalization
+ resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
+ resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
+ images_resnet_norm = (images - resnet_mean) / resnet_std
+
+ with torch.no_grad():
+ frame_feat = dino_v2_model(images_resnet_norm, is_training=True)
+
+ # Process features based on similarity type
+ if spatial_similarity:
+ frame_feat = frame_feat["x_norm_patchtokens"]
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
+
+ # Compute the similarity matrix
+ frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
+ similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
+ similarity_matrix = similarity_matrix.mean(dim=0)
+ else:
+ frame_feat = frame_feat["x_norm_clstoken"]
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
+ similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
+
+ distance_matrix = 100 - similarity_matrix.clone()
+
+ # Ignore self-pairing
+ similarity_matrix.fill_diagonal_(-100)
+ similarity_sum = similarity_matrix.sum(dim=1)
+
+ # Find the most common frame
+ most_common_frame_index = torch.argmax(similarity_sum).item()
+
+ # Conduct FPS sampling starting from the most common frame
+ fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index)
+
+ # Clean up all tensors and models to free memory
+ del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix
+ del dino_v2_model
+ torch.cuda.empty_cache()
+
+ return fps_idx
+
+
+def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
+ """
+ Farthest point sampling algorithm to select diverse frames.
+
+ Args:
+ distance_matrix: Matrix of distances between frames
+ num_samples: Number of frames to select
+ most_common_frame_index: Index of the first frame to select
+
+ Returns:
+ List of selected frame indices
+ """
+ distance_matrix = distance_matrix.clamp(min=0)
+ N = distance_matrix.size(0)
+
+ # Initialize with the most common frame
+ selected_indices = [most_common_frame_index]
+ check_distances = distance_matrix[selected_indices]
+
+ while len(selected_indices) < num_samples:
+ # Find the farthest point from the current set of selected points
+ farthest_point = torch.argmax(check_distances)
+ selected_indices.append(farthest_point.item())
+
+ check_distances = distance_matrix[farthest_point]
+ # Mark already selected points to avoid selecting them again
+ check_distances[selected_indices] = 0
+
+ # Break if all points have been selected
+ if len(selected_indices) == N:
+ break
+
+ return selected_indices
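+
+# Usage sketch (illustrative only; a random symmetric matrix stands in for the DINO distances):
+#   d = torch.rand(10, 10)
+#   d = (d + d.T) / 2
+#   d.fill_diagonal_(0)                                    # zero self-distance
+#   idx = farthest_point_sampling(d, num_samples=4)        # 4 mutually distant frame indices, starting at 0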
+
+
+def calculate_index_mappings(query_index, S, device=None):
+ """
+ Construct an order that switches [query_index] and [0]
+ so that the content of query_index would be placed at [0].
+
+ Args:
+ query_index: Index to swap with 0
+ S: Total number of elements
+ device: Device to place the tensor on
+
+ Returns:
+ Tensor of indices with the swapped order
+ """
+ new_order = torch.arange(S)
+ new_order[0] = query_index
+ new_order[query_index] = 0
+ if device is not None:
+ new_order = new_order.to(device)
+ return new_order
+
+
+def switch_tensor_order(tensors, order, dim=1):
+ """
+ Reorder tensors along a specific dimension according to the given order.
+
+ Args:
+ tensors: List of tensors to reorder
+ order: Tensor of indices specifying the new order
+ dim: Dimension along which to reorder
+
+ Returns:
+ List of reordered tensors
+ """
+ return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors]
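+
+# Usage sketch (illustrative only): swap frame 3 to the front, then reorder a frame tensor accordingly.
+#   order = calculate_index_mappings(query_index=3, S=5)   # tensor([3, 1, 2, 0, 4])
+#   frames = torch.rand(5, 3, 64, 64)
+#   (reordered,) = switch_tensor_order([frames], order, dim=0)
+#   assert torch.equal(reordered[0], frames[3])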
+
+
+def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"):
+ """
+ Initialize feature extractors that can be reused based on a method string.
+
+ Args:
+ max_query_num: Maximum number of keypoints to extract
+ det_thres: Detection threshold for keypoint extraction
+ extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
+ device: Device to run extraction on
+
+ Returns:
+ Dictionary of initialized extractors
+ """
+ extractors = {}
+ methods = extractor_method.lower().split("+")
+
+ for method in methods:
+ method = method.strip()
+ if method == "aliked":
+ aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
+ extractors["aliked"] = aliked_extractor.to(device).eval()
+ elif method == "sp":
+ sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres)
+ extractors["sp"] = sp_extractor.to(device).eval()
+ elif method == "sift":
+ sift_extractor = SIFT(max_num_keypoints=max_query_num)
+ extractors["sift"] = sift_extractor.to(device).eval()
+ else:
+ print(f"Warning: Unknown feature extractor '{method}', ignoring.")
+
+ if not extractors:
+ print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.")
+ aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
+ extractors["aliked"] = aliked_extractor.to(device).eval()
+
+ return extractors
+
+
+def extract_keypoints(query_image, extractors, round_keypoints=True):
+ """
+ Extract keypoints using pre-initialized feature extractors.
+
+ Args:
+ query_image: Input image tensor (3xHxW, range [0, 1])
+ extractors: Dictionary of initialized extractors
+ round_keypoints: Whether to round keypoint coordinates to integer pixel locations
+
+ Returns:
+ Tensor of keypoint coordinates (1xNx2)
+ """
+ query_points = None
+
+ with torch.no_grad():
+ for extractor_name, extractor in extractors.items():
+ query_points_data = extractor.extract(query_image, invalid_mask=None)
+ extractor_points = query_points_data["keypoints"]
+ if round_keypoints:
+ extractor_points = extractor_points.round()
+
+ if query_points is not None:
+ query_points = torch.cat([query_points, extractor_points], dim=1)
+ else:
+ query_points = extractor_points
+
+ return query_points
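+
+# Usage sketch (illustrative only; assumes a CUDA device and that the LightGlue extractors are installed):
+#   extractors = initialize_feature_extractors(2048, extractor_method="aliked+sp", device="cuda")
+#   img = torch.rand(3, 518, 518, device="cuda")           # single image in [0, 1]
+#   kpts = extract_keypoints(img, extractors, round_keypoints=False)   # -> (1, N, 2)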
+
+
+def predict_tracks_in_chunks(
+ track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960
+):
+ """
+ Process a list of query points to avoid memory issues.
+
+ Args:
+ track_predictor (object): The track predictor object used for predicting tracks.
+ images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images.
+ query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points.
+ fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker.
+ fine_tracking (bool): Whether to perform fine tracking.
+ num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility.
+
+ Returns:
+ tuple: A tuple containing the concatenated predicted tracks, visibility, and scores.
+ """
+ # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility
+ if not isinstance(query_points_list, (list, tuple)):
+ query_points = query_points_list
+ if num_splits is None:
+ num_splits = 1
+ query_points_list = torch.chunk(query_points, num_splits, dim=1)
+
+ # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple)
+ if isinstance(query_points_list, tuple):
+ query_points_list = list(query_points_list)
+
+ fine_pred_track_list = []
+ pred_vis_list = []
+ pred_score_list = []
+
+ for split_points in query_points_list:
+ # Feed into track predictor for each split
+ fine_pred_track, _, pred_vis, pred_score = track_predictor(
+ images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk
+ )
+ fine_pred_track_list.append(fine_pred_track)
+ pred_vis_list.append(pred_vis)
+ pred_score_list.append(pred_score)
+
+ # Concatenate the results from all splits
+ fine_pred_track = torch.cat(fine_pred_track_list, dim=2)
+ pred_vis = torch.cat(pred_vis_list, dim=2)
+
+ if pred_score is not None:
+ pred_score = torch.cat(pred_score_list, dim=2)
+ else:
+ pred_score = None
+
+ return fine_pred_track, pred_vis, pred_score
diff --git a/libs/vggt/heads/camera_head.py b/libs/vggt/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1ffb57d6c675dd7ef6166deaf4e9a3354b68dd
--- /dev/null
+++ b/libs/vggt/heads/camera_head.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from vggt.layers import Mlp
+from vggt.layers.block import Block
+from vggt.heads.head_act import activate_pose
+
+
+class CameraHead(nn.Module):
+ """
+ CameraHead predicts camera parameters from token representations using iterative refinement.
+
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+ """
+
+ def __init__(
+ self,
+ dim_in: int = 2048,
+ trunk_depth: int = 4,
+ pose_encoding_type: str = "absT_quaR_FoV",
+ num_heads: int = 16,
+ mlp_ratio: int = 4,
+ init_values: float = 0.01,
+ trans_act: str = "linear",
+ quat_act: str = "linear",
+ fl_act: str = "relu", # Field-of-view activation: ensures FoV values stay positive.
+ ):
+ super().__init__()
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ self.target_dim = 9
+ else:
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
+
+ self.trans_act = trans_act
+ self.quat_act = quat_act
+ self.fl_act = fl_act
+ self.trunk_depth = trunk_depth
+
+ # Build the trunk using a sequence of transformer blocks.
+ self.trunk = nn.Sequential(
+ *[
+ Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
+ for _ in range(trunk_depth)
+ ]
+ )
+
+ # Normalizations for camera token and trunk output.
+ self.token_norm = nn.LayerNorm(dim_in)
+ self.trunk_norm = nn.LayerNorm(dim_in)
+
+ # Learnable empty camera pose token.
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+ # Module for producing modulation parameters: shift, scale, and a gate.
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+ # Adaptive layer normalization without affine parameters.
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+ self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
+
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+ """
+ Forward pass to predict camera parameters.
+
+ Args:
+ aggregated_tokens_list (list): List of token tensors from the network;
+ the last tensor is used for prediction.
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+ Returns:
+ list: A list of predicted camera encodings (post-activation) from each iteration.
+ """
+ # Use tokens from the last block for camera prediction.
+ tokens = aggregated_tokens_list[-1]
+
+ # Extract the camera tokens
+ pose_tokens = tokens[:, :, 0]
+ pose_tokens = self.token_norm(pose_tokens)
+
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+ return pred_pose_enc_list
+
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+ """
+ Iteratively refine camera pose predictions.
+
+ Args:
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
+ num_iterations (int): Number of refinement iterations.
+
+ Returns:
+ list: List of activated camera encodings from each iteration.
+ """
+ B, S, C = pose_tokens.shape
+ pred_pose_enc = None
+ pred_pose_enc_list = []
+
+ for _ in range(num_iterations):
+ # Use a learned empty pose for the first iteration.
+ if pred_pose_enc is None:
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+ else:
+ # Detach the previous prediction to avoid backprop through time.
+ pred_pose_enc = pred_pose_enc.detach()
+ module_input = self.embed_pose(pred_pose_enc)
+
+ # Generate modulation parameters and split them into shift, scale, and gate components.
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
+
+ # Adaptive layer normalization and modulation.
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
+ # Compute the delta update for the pose encoding.
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
+
+ if pred_pose_enc is None:
+ pred_pose_enc = pred_pose_enc_delta
+ else:
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+ # Apply final activation functions for translation, quaternion, and field-of-view.
+ activated_pose = activate_pose(
+ pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
+ )
+ pred_pose_enc_list.append(activated_pose)
+
+ return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Modulate the input tensor using scaling and shifting parameters.
+ """
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
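+
+# Sanity-check sketch (illustrative only): with zero shift and scale, modulate is the identity.
+#   x = torch.randn(2, 8, 2048)
+#   zeros = torch.zeros_like(x)
+#   assert torch.equal(modulate(x, zeros, zeros), x)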
diff --git a/libs/vggt/heads/dpt_head.py b/libs/vggt/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..73978a87bf3ff134e53076ad20135bfac3045341
--- /dev/null
+++ b/libs/vggt/heads/dpt_head.py
@@ -0,0 +1,484 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
+
+
+import os
+from typing import List, Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .head_act import activate_head
+from .utils import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ DPT Head for dense prediction tasks.
+
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
+ backbone and produces dense predictions by fusing multi-scale features.
+
+ Args:
+ dim_in (int): Input dimension (channels).
+ patch_size (int, optional): Patch size. Default is 14.
+ output_dim (int, optional): Number of output channels. Default is 4.
+ activation (str, optional): Activation type. Default is "inv_log".
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
+ out_channels (List[int], optional): Output channels for each intermediate layer.
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.conf_activation = conf_activation
+ self.pos_embed = pos_embed
+ self.feature_only = feature_only
+ self.down_ratio = down_ratio
+ self.intermediate_layer_idx = intermediate_layer_idx
+
+ self.norm = nn.LayerNorm(dim_in)
+
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList(
+ [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
+ )
+
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+
+ self.scratch = _make_scratch(out_channels, features, expand=False)
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = None
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if feature_only:
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+ )
+ conv2_in_channels = head_features_1 // 2
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass through the DPT head, supports processing by chunking frames.
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
+ If None or larger than S, all frames are processed at once. Default: 8.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]:
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+ all_conf = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ if self.feature_only:
+ chunk_output = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_output)
+ else:
+ chunk_preds, chunk_conf = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_preds)
+ all_conf.append(chunk_conf)
+
+ # Concatenate results along the sequence dimension
+ if self.feature_only:
+ return torch.cat(all_preds, dim=1)
+ else:
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
+
+ def _forward_impl(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_start_idx: int = None,
+ frames_end_idx: int = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Implementation of the forward pass through the DPT head.
+
+ This method processes a specific chunk of frames from the sequence.
+
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W].
+ patch_start_idx (int): Starting index for patch tokens.
+ frames_start_idx (int, optional): Starting index for frames to process.
+ frames_end_idx (int, optional): Ending index for frames to process.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
+ """
+ if frames_start_idx is not None and frames_end_idx is not None:
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
+
+ B, S, _, H, W = images.shape
+
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ out = []
+ dpt_idx = 0
+
+ for layer_idx in self.intermediate_layer_idx:
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx]
+
+ x = x.reshape(B * S, -1, x.shape[-1])
+
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H)
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ # Interpolate fused output to match target image resolution.
+ out = custom_interpolate(
+ out,
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H)
+
+ if self.feature_only:
+ return out.view(B, S, *out.shape[1:])
+
+ out = self.scratch.output_conv2(out)
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
+
+ preds = preds.view(B, S, *preds.shape[1:])
+ conf = conf.view(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+
+################################################################################
+# Modules
+################################################################################
+
+
+def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
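+
+# Usage sketch (illustrative only): small inputs take the plain F.interpolate path.
+#   x = torch.rand(1, 3, 64, 64)
+#   y = custom_interpolate(x, size=(128, 128))             # -> (1, 3, 128, 128)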
diff --git a/libs/vggt/heads/head_act.py b/libs/vggt/heads/head_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489
--- /dev/null
+++ b/libs/vggt/heads/head_act.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
+ """
+ Activate pose parameters with specified activation functions.
+
+ Args:
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+ trans_act: Activation type for translation component
+ quat_act: Activation type for quaternion component
+ fl_act: Activation type for focal length component
+
+ Returns:
+ Activated pose parameters tensor
+ """
+ T = pred_pose_enc[..., :3]
+ quat = pred_pose_enc[..., 3:7]
+ fl = pred_pose_enc[..., 7:] # or fov
+
+ T = base_pose_act(T, trans_act)
+ quat = base_pose_act(quat, quat_act)
+ fl = base_pose_act(fl, fl_act) # or fov
+
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+ return pred_pose_enc
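+
+# Usage sketch (illustrative only; random encodings, assumed layout 3 translation + 4 quaternion + 2 FoV):
+#   enc = torch.randn(2, 8, 9)
+#   out = activate_pose(enc, trans_act="linear", quat_act="linear", fl_act="relu")
+#   assert out.shape == enc.shape                          # only the last two (FoV) channels are modified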
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+ """
+ Apply basic activation function to pose parameters.
+
+ Args:
+ pose_enc: Tensor containing encoded pose parameters
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated pose parameters
+ """
+ if act_type == "linear":
+ return pose_enc
+ elif act_type == "inv_log":
+ return inverse_log_transform(pose_enc)
+ elif act_type == "exp":
+ return torch.exp(pose_enc)
+ elif act_type == "relu":
+ return F.relu(pose_enc)
+ else:
+ raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+ """
+ Process network output to extract 3D points and confidence values.
+
+ Args:
+ out: Network output tensor (B, C, H, W)
+ activation: Activation type for 3D points
+ conf_activation: Activation type for confidence values
+
+ Returns:
+ Tuple of (3D points tensor, confidence tensor)
+ """
+ # Move the channel dim to the end => (B, H, W, C)
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
+
+ # Split into xyz (first C-1 channels) and confidence (last channel)
+ xyz = fmap[:, :, :, :-1]
+ conf = fmap[:, :, :, -1]
+
+ if activation == "norm_exp":
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ xyz_normed = xyz / d
+ pts3d = xyz_normed * torch.expm1(d)
+ elif activation == "norm":
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+ elif activation == "exp":
+ pts3d = torch.exp(xyz)
+ elif activation == "relu":
+ pts3d = F.relu(xyz)
+ elif activation == "inv_log":
+ pts3d = inverse_log_transform(xyz)
+ elif activation == "xy_inv_log":
+ xy, z = xyz.split([2, 1], dim=-1)
+ z = inverse_log_transform(z)
+ pts3d = torch.cat([xy * z, z], dim=-1)
+ elif activation == "sigmoid":
+ pts3d = torch.sigmoid(xyz)
+ elif activation == "linear":
+ pts3d = xyz
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ if conf_activation == "expp1":
+ conf_out = 1 + conf.exp()
+ elif conf_activation == "expp0":
+ conf_out = conf.exp()
+ elif conf_activation == "sigmoid":
+ conf_out = torch.sigmoid(conf)
+ else:
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+ return pts3d, conf_out
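+
+# Usage sketch (illustrative only; a random 4-channel map standing in for the network output):
+#   out = torch.randn(2, 4, 32, 32)                        # (B, 3 + 1, H, W): xyz channels + confidence
+#   pts3d, conf = activate_head(out)                       # pts3d: (2, 32, 32, 3), conf: (2, 32, 32)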
+
+
+def inverse_log_transform(y):
+ """
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ y: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
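+
+# Worked example (illustrative only):
+#   inverse_log_transform(torch.tensor([-1.0, 0.0, 1.0]))  # -> tensor([-1.7183, 0.0000, 1.7183])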
diff --git a/libs/vggt/heads/track_head.py b/libs/vggt/heads/track_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4f1d9bd83cca1f74f97a644a02b984904f84706
--- /dev/null
+++ b/libs/vggt/heads/track_head.py
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from .dpt_head import DPTHead
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackHead(nn.Module):
+ """
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
+ The tracking is performed iteratively, refining predictions over multiple iterations.
+ """
+
+ def __init__(
+ self,
+ dim_in,
+ patch_size=14,
+ features=128,
+ iters=4,
+ predict_conf=True,
+ stride=2,
+ corr_levels=7,
+ corr_radius=4,
+ hidden_size=384,
+ ):
+ """
+ Initialize the TrackHead module.
+
+ Args:
+ dim_in (int): Input dimension of tokens from the backbone.
+ patch_size (int): Size of image patches used in the vision transformer.
+ features (int): Number of feature channels in the feature extractor output.
+ iters (int): Number of refinement iterations for tracking predictions.
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
+ stride (int): Stride value for the tracker predictor.
+ corr_levels (int): Number of correlation pyramid levels
+ corr_radius (int): Radius for correlation computation, controlling the search area.
+ hidden_size (int): Size of hidden layers in the tracker network.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ # Feature extractor based on DPT architecture
+ # Processes tokens into feature maps for tracking
+ self.feature_extractor = DPTHead(
+ dim_in=dim_in,
+ patch_size=patch_size,
+ features=features,
+ feature_only=True, # Only output features, no activation
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
+ pos_embed=False,
+ )
+
+ # Tracker module that predicts point trajectories
+ # Takes feature maps and predicts coordinates and visibility
+ self.tracker = BaseTrackerPredictor(
+ latent_dim=features, # Match the output_dim of feature extractor
+ predict_conf=predict_conf,
+ stride=stride,
+ corr_levels=corr_levels,
+ corr_radius=corr_radius,
+ hidden_size=hidden_size,
+ )
+
+ self.iters = iters
+
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
+ """
+ Forward pass of the TrackHead.
+
+ Args:
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
+ B = batch size, S = sequence length.
+ patch_start_idx (int): Starting index for patch tokens.
+ query_points (torch.Tensor, optional): Initial query points to track.
+ If None, points are initialized by the tracker.
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
+
+ Returns:
+ tuple:
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
+ """
+ B, S, _, H, W = images.shape
+
+ # Extract features from tokens
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
+
+ # Use default iterations if not specified
+ if iters is None:
+ iters = self.iters
+
+ # Perform tracking using the extracted features
+ coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters)
+
+ return coord_preds, vis_scores, conf_scores
diff --git a/libs/vggt/heads/track_modules/__init__.py b/libs/vggt/heads/track_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/libs/vggt/heads/track_modules/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/libs/vggt/heads/track_modules/base_track_predictor.py b/libs/vggt/heads/track_modules/base_track_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a
--- /dev/null
+++ b/libs/vggt/heads/track_modules/base_track_predictor.py
@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+from .blocks import EfficientUpdateFormer, CorrBlock
+from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
+from .modules import Mlp
+
+
+class BaseTrackerPredictor(nn.Module):
+ def __init__(
+ self,
+ stride=1,
+ corr_levels=5,
+ corr_radius=4,
+ latent_dim=128,
+ hidden_size=384,
+ use_spaceatt=True,
+ depth=6,
+ max_scale=518,
+ predict_conf=True,
+ ):
+        """
+        The base template to create a track predictor
+
+        Modified from https://github.com/facebookresearch/co-tracker/
+        and https://github.com/facebookresearch/vggsfm
+        """
+        super(BaseTrackerPredictor, self).__init__()
+
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.corr_levels = corr_levels
+ self.corr_radius = corr_radius
+ self.hidden_size = hidden_size
+ self.max_scale = max_scale
+ self.predict_conf = predict_conf
+
+ self.flows_emb_dim = latent_dim // 2
+
+ self.corr_mlp = Mlp(
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
+ hidden_features=self.hidden_size,
+ out_features=self.latent_dim,
+ )
+
+        # = flow embedding (latent_dim + 4 raw-flow channels) + correlation features (latent_dim) + track features (latent_dim)
+        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
+
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
+
+ space_depth = depth if use_spaceatt else 0
+ time_depth = depth
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=self.transformer_dim,
+ hidden_size=self.hidden_size,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=use_spaceatt,
+ )
+
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
+
+ # A linear layer to update track feats at each iteration
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
+
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ if predict_conf:
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
+ """
+ query_points: B x N x 2, the number of batches, tracks, and xy
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
+ note HH and WW is the size of feature maps instead of original images
+ """
+ B, N, D = query_points.shape
+ B, S, C, HH, WW = fmaps.shape
+
+ assert D == 2, "Input points must be 2D coordinates"
+
+ # apply a layernorm to fmaps here
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
+
+ # Scale the input query_points because we may downsample the images
+ # by down_ratio or self.stride
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
+ # its query_points should be query_points/4
+ if down_ratio > 1:
+ query_points = query_points / float(down_ratio)
+
+ query_points = query_points / float(self.stride)
+
+ # Init with coords as the query points
+ # It means the search will start from the position of query points at the reference frames
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
+
+ # Sample/extract the features of the query points in the query frame
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
+
+ # init track feats by query feats
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
+ # back up the init coords
+ coords_backup = coords.clone()
+
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
+
+ coord_preds = []
+
+ # Iterative Refinement
+ for _ in range(iters):
+ # Detach the gradients from the last iteration
+ # (in my experience, not very important for performance)
+ coords = coords.detach()
+
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
+
+ corr_dim = fcorrs.shape[3]
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
+ fcorrs_ = self.corr_mlp(fcorrs_)
+
+ # Movement of current coords relative to query points
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
+
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
+
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ # Concatenate them as the input for the transformers
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
+
+ # 2D positional embed
+ # TODO: this can be much simplified
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
+
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
+
+ x = transformer_input + sampled_pos_emb
+
+ # Add the query ref token to the track feats
+ query_ref_token = torch.cat(
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
+ )
+ x = x + query_ref_token.to(x.device).to(x.dtype)
+
+ # B, N, S, C
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
+
+ # Compute the delta coordinates and delta track features
+ delta, _ = self.updateformer(x)
+
+ # BN, S, C
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+
+ # Update the track features
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
+
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
+
+ # B x S x N x 2
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+
+ # Force coord0 as query
+ # because we assume the query points should not be changed
+ coords[:, 0] = coords_backup[:, 0]
+
+ # The predicted tracks are in the original image scale
+ if down_ratio > 1:
+ coord_preds.append(coords * self.stride * down_ratio)
+ else:
+ coord_preds.append(coords * self.stride)
+
+ # B, S, N
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ vis_e = torch.sigmoid(vis_e)
+
+ if self.predict_conf:
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ conf_e = torch.sigmoid(conf_e)
+ else:
+ conf_e = None
+
+ if return_feat:
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
+ else:
+ return coord_preds, vis_e, conf_e
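+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code); run it as a
+    # module (python -m ...) so the relative imports above resolve.
+    B, S, N, C, HH, WW = 1, 4, 8, 128, 32, 32
+    predictor = BaseTrackerPredictor(latent_dim=C)
+    fmaps = torch.randn(B, S, C, HH, WW)
+    query_points = torch.rand(B, N, 2) * (WW - 1)
+    coord_preds, vis, conf = predictor(query_points, fmaps=fmaps, iters=2)
+    print(coord_preds[-1].shape, vis.shape, conf.shape)  # (1, 4, 8, 2) (1, 4, 8) (1, 4, 8)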
diff --git a/libs/vggt/heads/track_modules/blocks.py b/libs/vggt/heads/track_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c161c89ef99742b0f2c6f397c9121fe9301e08
--- /dev/null
+++ b/libs/vggt/heads/track_modules/blocks.py
@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Modified from https://github.com/facebookresearch/co-tracker/
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import bilinear_sampler
+from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+
+ # Add input LayerNorm before linear projection
+ self.input_norm = nn.LayerNorm(input_dim)
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+
+ # Add output LayerNorm before final projection
+ self.output_norm = nn.LayerNorm(hidden_size)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+
+        # Note: "virual_tracks" (sic) is kept as-is; renaming the attribute would change
+        # the state_dict keys of pretrained checkpoints.
+        if self.add_space_attn:
+            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+        else:
+            self.virual_tracks = None
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ # Apply input LayerNorm
+ input_tensor = self.input_norm(input_tensor)
+ tokens = self.input_transform(input_tensor)
+
+ init_tokens = tokens
+
+ B, _, T, _ = tokens.shape
+
+ if self.add_space_attn:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+
+ if self.add_space_attn:
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ tokens = tokens + init_tokens
+
+ # Apply output LayerNorm before final projection
+ tokens = self.output_norm(tokens)
+ flow = self.flow_head(tokens)
+
+ return flow, None
+
+
+class CorrBlock:
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
+ """
+ Build a pyramid of feature maps from the input.
+
+ fmaps: Tensor (B, S, C, H, W)
+ num_levels: number of pyramid levels (each downsampled by factor 2)
+ radius: search radius for sampling correlation
+ multiple_track_feats: if True, split the target features per pyramid level
+ padding_mode: passed to grid_sample / bilinear_sampler
+ """
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.num_levels = num_levels
+ self.radius = radius
+ self.padding_mode = padding_mode
+ self.multiple_track_feats = multiple_track_feats
+
+ # Build pyramid: each level is half the spatial resolution of the previous
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
+ current_fmaps = fmaps
+ for i in range(num_levels - 1):
+ B, S, C, H, W = current_fmaps.shape
+ # Merge batch & sequence dimensions
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
+ # Avg pool down by factor 2
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
+ _, _, H_new, W_new = current_fmaps.shape
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
+ self.fmaps_pyramid.append(current_fmaps)
+
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
+ # This grid is added to the (scaled) coordinate centroids.
+ r = self.radius
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+        # delta holds the (dy, dx) offsets of the (2r+1) x (2r+1) sampling window around each centroid
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
+
+ def corr_sample(self, targets, coords):
+ """
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
+ volume, sample it immediately, then discard it. This saves GPU memory.
+
+ Args:
+ targets: Tensor (B, S, N, C) — features for the current targets.
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
+
+ Returns:
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
+ """
+ B, S, N, C = targets.shape
+
+ # If you have multiple track features, split them per level.
+ if self.multiple_track_feats:
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
+
+ out_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ # Get current spatial resolution H, W for this pyramid level.
+ B, S, C, H, W = fmaps.shape
+ # Reshape feature maps for correlation computation:
+ # fmap2s: (B, S, C, H*W)
+ fmap2s = fmaps.view(B, S, C, H * W)
+ # Choose appropriate target features.
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
+
+ # Compute correlation directly
+ corrs = compute_corr_level(fmap1, fmap2s, C)
+ corrs = corrs.view(B, S, N, H, W)
+
+ # Prepare sampling grid:
+ # Scale down the coordinates for the current level.
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
+ # Make sure our precomputed delta grid is on the same device/dtype.
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
+ # Now the grid for grid_sample is:
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
+
+ # Sample from the correlation volume using bilinear interpolation.
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
+ corrs_sampled = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
+ )
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
+ out_pyramid.append(corrs_sampled)
+
+ # Concatenate all levels along the last dimension.
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
+ return out
+
+
+def compute_corr_level(fmap1, fmap2s, C):
+ # fmap1: (B, S, N, C)
+ # fmap2s: (B, S, C, H*W)
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
+ return corrs / math.sqrt(C)
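+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code); run it as a
+    # module so the relative imports above resolve.
+    B, S, N, C, H, W = 1, 3, 5, 64, 32, 32
+    fmaps = torch.randn(B, S, C, H, W)
+    corr = CorrBlock(fmaps, num_levels=4, radius=3)
+    targets = torch.randn(B, S, N, C)
+    coords = torch.rand(B, S, N, 2) * (W - 1)
+    print(corr.corr_sample(targets, coords).shape)  # (1, 3, 5, 4 * 7 * 7) = (1, 3, 5, 196)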
diff --git a/libs/vggt/heads/track_modules/modules.py b/libs/vggt/heads/track_modules/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..12de4f1ad76364d4665e53ac80e1037fadf98d08
--- /dev/null
+++ b/libs/vggt/heads/track_modules/modules.py
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class ResidualBlock(nn.Module):
+ """
+ ResidualBlock: construct a block of two conv layers with residual connections
+ """
+
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
+ )
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+ else:
+ raise NotImplementedError
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
+ mlp_ratio=4.0,
+ **block_kwargs,
+ ):
+ """
+ Self attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+    def forward(self, x, mask=None):
+        # Note: `mask` is accepted for interface compatibility but is not applied here.
+        # Pre-norm before attention
+        x = self.norm1(x)
+
+        # nn.MultiheadAttention returns (attn_output, attn_output_weights)
+        attn_output, _ = self.attn(x, x, x)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ """
+ Cross attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, context, mask=None):
+ # Normalize inputs
+ x = self.norm1(x)
+ context = self.norm_context(context)
+
+ # Apply cross attention
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
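+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code).
+    x = torch.randn(2, 10, 128)
+    context = torch.randn(2, 20, 128)
+    print(Mlp(in_features=128)(x).shape)  # torch.Size([2, 10, 128])
+    print(AttnBlock(hidden_size=128, num_heads=4)(x).shape)  # torch.Size([2, 10, 128])
+    print(CrossAttnBlock(hidden_size=128, context_dim=128, num_heads=4)(x, context).shape)  # torch.Size([2, 10, 128])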
diff --git a/libs/vggt/heads/track_modules/utils.py b/libs/vggt/heads/track_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f1fffeaedd33c7f1c2ef54220e24a2a0e5a57b2
--- /dev/null
+++ b/libs/vggt/heads/track_modules/utils.py
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/vggsfm
+# and https://github.com/facebookresearch/co-tracker/tree/main
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Union
+
+
+def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if return_grid:
+ return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+    pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, 2*C)
+    if cat_coords:
+        pe = torch.cat([xy, pe], dim=2) # (B, N, 2*C + 2)
+ return pe
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+    left-most image pixel and :math:`W-1` to the center of the right-most
+    pixel.
+
+    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+    the left-most pixel and :math:`W` to the right edge of the right-most
+    pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+    coords = coords.detach().clone()
+    # Important: match the input's device and dtype before the in-place scaling below
+    coords = coords.to(input.device).to(input.dtype)
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ scale = torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
+ )
+ else:
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
+
+ coords.mul_(scale) # coords = coords * scale
+ coords.sub_(1) # coords = coords - 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
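+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code).
+    print(get_2d_sincos_pos_embed(embed_dim=64, grid_size=(8, 8)).shape)  # torch.Size([1, 64, 8, 8])
+    feats = torch.randn(2, 16, 8, 8)
+    points = torch.rand(2, 5, 2) * 7  # (x, y) in [0, 7]
+    print(sample_features4d(feats, points).shape)  # torch.Size([2, 5, 16])
+    print(get_2d_embedding(points, C=32).shape)  # torch.Size([2, 5, 66])  (= 2 coords + 2 * C)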
diff --git a/libs/vggt/heads/utils.py b/libs/vggt/heads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..533fc8ae67a75cd0a94d5ca96dc5a0513446c64f
--- /dev/null
+++ b/libs/vggt/heads/utils.py
@@ -0,0 +1,109 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Process x and y coordinates separately
+    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [H*W, D/2]
+    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [H*W, D/2]
+
+    # Combine and reshape
+    emb = torch.cat([emb_x, emb_y], dim=-1) # [H*W, D]
+
+ return emb.view(H, W, embed_dim) # [H, W, D]
+
+
+def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ device = pos.device
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / omega_0**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb.float()
+
+
+# Inspired by https://github.com/microsoft/moge
+
+
+def create_uv_grid(
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (width, height, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create 2D meshgrid (width x height) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
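+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code). For a square
+    # grid the spans are 1/sqrt(2), so corner sample centers sit near +-0.53 for a 4x4 grid.
+    uv = create_uv_grid(4, 4)
+    print(uv.shape, uv[0, 0], uv[-1, -1])  # torch.Size([4, 4, 2]) ~(-0.53, -0.53) ~(0.53, 0.53)
+    print(position_grid_to_embed(uv, embed_dim=32).shape)  # torch.Size([4, 4, 32])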
diff --git a/libs/vggt/layers/__init__.py b/libs/vggt/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/libs/vggt/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/libs/vggt/layers/attention.py b/libs/vggt/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e823b4b7a93cca75e4cbab1cdfbbc3121a316fa
--- /dev/null
+++ b/libs/vggt/layers/attention.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F
+
+XFORMERS_AVAILABLE = False  # xFormers disabled: MemEffAttention falls back to the standard attention path below
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.fused_attn = fused_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.rope is not None:
+ q = self.rope(q, pos)
+ k = self.rope(k, pos)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
+ assert pos is None
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
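+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code).
+    import torch
+
+    attn = Attention(dim=256, num_heads=8)
+    x = torch.randn(2, 50, 256)
+    print(attn(x).shape)  # torch.Size([2, 50, 256])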
diff --git a/libs/vggt/layers/block.py b/libs/vggt/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc5847352a1f8f5d63da28c99e94270e50ccf3aa
--- /dev/null
+++ b/libs/vggt/layers/block.py
@@ -0,0 +1,247 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+XFORMERS_AVAILABLE = False  # xFormers disabled: NestedTensorBlock.forward only accepts plain tensors (list inputs raise)
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ fused_attn=fused_attn,
+ rope=rope,
+ )
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
+ )
+ x = drop_add_residual_stochastic_depth(
+ x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
+            x = x + self.drop_path2(ffn_residual_func(x))
+ else:
+ x = x + attn_residual_func(x, pos=pos)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ if pos is not None:
+ # if necessary, apply rope to the subset
+ pos = pos[brange]
+ residual = residual_func(x_subset, pos=pos)
+ else:
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=(self.ls2.gamma if isinstance(self.ls2, LayerScale) else None),
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
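+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code); run it as a
+    # module so the relative imports above resolve.
+    block = Block(dim=384, num_heads=6)
+    x = torch.randn(2, 196, 384)
+    print(block(x).shape)  # torch.Size([2, 196, 384])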
diff --git a/libs/vggt/layers/drop_path.py b/libs/vggt/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/libs/vggt/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/libs/vggt/layers/layer_scale.py b/libs/vggt/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ddfc51c3d87370d50175f5b4e649dac1c614ff9
--- /dev/null
+++ b/libs/vggt/layers/layer_scale.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/libs/vggt/layers/mlp.py b/libs/vggt/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/libs/vggt/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/libs/vggt/layers/patch_embed.py b/libs/vggt/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc19605e4d6e88d06355ae3b1afddc76f595aafe
--- /dev/null
+++ b/libs/vggt/layers/patch_embed.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
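+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the upstream code): a 224x224 RGB
+    # image with 16x16 patches yields (224 / 16) ** 2 = 196 tokens.
+    import torch
+
+    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
+    print(embed(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 196, 768])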
diff --git a/libs/vggt/layers/rope.py b/libs/vggt/layers/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df
--- /dev/null
+++ b/libs/vggt/layers/rope.py
@@ -0,0 +1,188 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+# Implementation of 2D Rotary Position Embeddings (RoPE).
+
+# This module provides a clean implementation of 2D Rotary Position Embeddings,
+# which extends the original RoPE concept to handle 2D spatial positions.
+
+# Inspired by:
+# https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# https://github.com/naver-ai/rope-vit
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, Tuple
+
+
+class PositionGetter:
+ """Generates and caches 2D spatial positions for patches in a grid.
+
+ This class efficiently manages the generation of spatial coordinates for patches
+ in a 2D grid, caching results to avoid redundant computations.
+
+ Attributes:
+ position_cache: Dictionary storing precomputed position tensors for different
+ grid dimensions.
+ """
+
+ def __init__(self):
+ """Initializes the position generator with an empty cache."""
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
+ """Generates spatial positions for a batch of patches.
+
+ Args:
+ batch_size: Number of samples in the batch.
+ height: Height of the grid in patches.
+ width: Width of the grid in patches.
+ device: Target device for the position tensor.
+
+ Returns:
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
+ for each position in the grid, repeated for each batch item.
+ """
+ if (height, width) not in self.position_cache:
+ y_coords = torch.arange(height, device=device)
+ x_coords = torch.arange(width, device=device)
+ positions = torch.cartesian_prod(y_coords, x_coords)
+ self.position_cache[height, width] = positions
+
+ cached_positions = self.position_cache[height, width]
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
+
+
+class RotaryPositionEmbedding2D(nn.Module):
+ """2D Rotary Position Embedding implementation.
+
+ This module applies rotary position embeddings to input tokens based on their
+ 2D spatial positions. It handles the position-dependent rotation of features
+ separately for vertical and horizontal dimensions.
+
+ Args:
+ frequency: Base frequency for the position embeddings. Default: 100.0
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
+
+ Attributes:
+ base_frequency: Base frequency for computing position embeddings.
+ scaling_factor: Factor to scale the computed frequencies.
+ frequency_cache: Cache for storing precomputed frequency components.
+ """
+
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
+ """Initializes the 2D RoPE module."""
+ super().__init__()
+ self.base_frequency = frequency
+ self.scaling_factor = scaling_factor
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
+
+ def _compute_frequency_components(
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes frequency components for rotary embeddings.
+
+ Args:
+ dim: Feature dimension (must be even).
+ seq_len: Maximum sequence length.
+ device: Target device for computations.
+ dtype: Data type for the computed tensors.
+
+ Returns:
+ Tuple of (cosine, sine) tensors for frequency components.
+ """
+ cache_key = (dim, seq_len, device, dtype)
+ if cache_key not in self.frequency_cache:
+ # Compute frequency bands
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
+ inv_freq = 1.0 / (self.base_frequency**exponents)
+
+ # Generate position-dependent frequencies
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
+
+ # Compute and cache frequency components
+ angles = angles.to(dtype)
+ angles = torch.cat((angles, angles), dim=-1)
+ cos_components = angles.cos().to(dtype)
+ sin_components = angles.sin().to(dtype)
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
+
+ return self.frequency_cache[cache_key]
+
+ @staticmethod
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
+ """Performs feature rotation by splitting and recombining feature dimensions.
+
+ Args:
+ x: Input tensor to rotate.
+
+ Returns:
+ Rotated feature tensor.
+ """
+ feature_dim = x.shape[-1]
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _apply_1d_rope(
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
+ ) -> torch.Tensor:
+ """Applies 1D rotary position embeddings along one dimension.
+
+ Args:
+ tokens: Input token features.
+ positions: Position indices.
+ cos_comp: Cosine components for rotation.
+ sin_comp: Sine components for rotation.
+
+ Returns:
+ Tokens with applied rotary position embeddings.
+ """
+ # Embed positions with frequency components
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
+
+ # Apply rotation
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
+
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ """Applies 2D rotary position embeddings to input tokens.
+
+ Args:
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
+ The feature dimension (dim) must be divisible by 4.
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
+ the y and x coordinates for each token.
+
+ Returns:
+ Tensor of same shape as input with applied 2D rotary position embeddings.
+
+ Raises:
+ AssertionError: If input dimensions are invalid or positions are malformed.
+ """
+ # Validate inputs
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
+
+ # Compute feature dimension for each spatial direction
+ feature_dim = tokens.size(-1) // 2
+
+ # Get frequency components
+ max_position = int(positions.max()) + 1
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
+
+ # Split features for vertical and horizontal processing
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
+
+ # Apply RoPE separately for each dimension
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
+
+ # Combine processed features
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
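+
+
+# Usage sketch (illustrative only, not part of the original file): the head
+# dimension must be divisible by 4 (two spatial axes, each rotated in pairs).
+#
+#     >>> pos_getter = PositionGetter()
+#     >>> rope = RotaryPositionEmbedding2D(frequency=100.0)
+#     >>> tokens = torch.randn(2, 8, 7 * 7, 64)            # (B, heads, H*W, dim)
+#     >>> positions = pos_getter(2, 7, 7, tokens.device)   # (B, H*W, 2)
+#     >>> rope(tokens, positions).shape
+#     torch.Size([2, 8, 49, 64])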
diff --git a/libs/vggt/layers/swiglu_ffn.py b/libs/vggt/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dd991e1deb87141ccd282098d4b9d38fed6ef25
--- /dev/null
+++ b/libs/vggt/layers/swiglu_ffn.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+# The xFormers fused SwiGLU kernel is intentionally disabled in this copy;
+# the pure-PyTorch SwiGLUFFN above is used in its place.
+SwiGLU = SwiGLUFFN
+XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ # Keep roughly 2/3 of the requested hidden width (SwiGLU convention) and
+ # round up to a multiple of 8 for hardware-friendly matrix shapes.
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
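+
+
+# Usage sketch (illustrative only, not part of the original file): the fused
+# variant keeps roughly 2/3 of the requested hidden width, rounded up to a
+# multiple of 8.
+#
+#     >>> import torch
+#     >>> ffn = SwiGLUFFNFused(in_features=1024, hidden_features=4096)
+#     >>> ffn.w12.out_features, ffn.w3.in_features
+#     (5472, 2736)
+#     >>> ffn(torch.randn(2, 16, 1024)).shape
+#     torch.Size([2, 16, 1024])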
diff --git a/libs/vggt/layers/vision_transformer.py b/libs/vggt/layers/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..deda8fde42b1b5b3340132c9c75338c65c9bea3f
--- /dev/null
+++ b/libs/vggt/layers/vision_transformer.py
@@ -0,0 +1,397 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ qk_norm=False,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (float): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+ self.use_reentrant = False # hardcoded to False
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=True, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
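+
+
+# Usage sketch (illustrative only, random weights; not part of the original file):
+#
+#     >>> model = vit_small(patch_size=14, img_size=224, num_register_tokens=4).eval()
+#     >>> out = model(torch.randn(1, 3, 224, 224), is_training=True)
+#     >>> out["x_norm_patchtokens"].shape
+#     torch.Size([1, 256, 384])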
diff --git a/libs/vggt/models/aggregator.py b/libs/vggt/models/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6b25d6df44a0dbf71b214f5084b2a21fcd087e
--- /dev/null
+++ b/libs/vggt/models/aggregator.py
@@ -0,0 +1,331 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+from typing import Optional, Tuple, Union, List, Dict, Any
+
+from vggt.layers import PatchEmbed
+from vggt.layers.block import Block
+from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+class Aggregator(nn.Module):
+ """
+ The Aggregator applies alternating-attention over input frames,
+ as described in VGGT: Visual Geometry Grounded Transformer.
+
+ Remember to set model.train() to enable gradient checkpointing to reduce memory usage.
+
+ Args:
+ img_size (int): Image size in pixels.
+ patch_size (int): Size of each patch for PatchEmbed.
+ embed_dim (int): Dimension of the token embeddings.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ num_register_tokens (int): Number of register tokens.
+ block_fn (nn.Module): The block type used for attention (Block by default).
+ qkv_bias (bool): Whether to include bias in QKV projections.
+ proj_bias (bool): Whether to include bias in the output projection.
+ ffn_bias (bool): Whether to include bias in MLP layers.
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
+ qk_norm (bool): Whether to apply QK normalization.
+ rope_freq (int): Base frequency for rotary embedding. Set to a value <= 0 (e.g., -1) to disable.
+ init_values (float): Init scale for layer scale.
+ """
+
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_register_tokens=4,
+ block_fn=Block,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ patch_embed="dinov2_vitl14_reg",
+ aa_order=["frame", "global"],
+ aa_block_size=1,
+ qk_norm=True,
+ rope_freq=100,
+ init_values=0.01,
+ ):
+ super().__init__()
+
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
+
+ # Initialize rotary position embedding if frequency > 0
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
+ self.position_getter = PositionGetter() if self.rope is not None else None
+
+ self.frame_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.global_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.depth = depth
+ self.aa_order = aa_order
+ self.patch_size = patch_size
+ self.aa_block_size = aa_block_size
+
+ # Validate that depth is divisible by aa_block_size
+ if self.depth % self.aa_block_size != 0:
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
+
+ self.aa_block_num = self.depth // self.aa_block_size
+
+ # Note: We have two camera tokens, one for the first frame and one for the rest
+ # The same applies for register tokens
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
+
+ # The patch tokens start after the camera and register tokens
+ self.patch_start_idx = 1 + num_register_tokens
+
+ # Initialize parameters with small values
+ nn.init.normal_(self.camera_token, std=1e-6)
+ nn.init.normal_(self.register_token, std=1e-6)
+
+ # Register normalization constants as buffers
+ for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
+ self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
+
+ self.use_reentrant = False # hardcoded to False
+
+ def __build_patch_embed__(
+ self,
+ patch_embed,
+ img_size,
+ patch_size,
+ num_register_tokens,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ block_chunks=0,
+ init_values=1.0,
+ embed_dim=1024,
+ ):
+ """
+ Build the patch embed layer. If 'conv', we use a
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
+ """
+
+ if "conv" in patch_embed:
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
+ else:
+ vit_models = {
+ "dinov2_vitl14_reg": vit_large,
+ "dinov2_vitb14_reg": vit_base,
+ "dinov2_vits14_reg": vit_small,
+ "dinov2_vitg2_reg": vit_giant2,
+ }
+
+ self.patch_embed = vit_models[patch_embed](
+ img_size=img_size,
+ patch_size=patch_size,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ block_chunks=block_chunks,
+ init_values=init_values,
+ )
+
+ # Disable gradient updates for mask token
+ if hasattr(self.patch_embed, "mask_token"):
+ self.patch_embed.mask_token.requires_grad_(False)
+
+ def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]:
+ """
+ Args:
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+
+ Returns:
+ (list[torch.Tensor], int):
+ The list of outputs from the attention blocks,
+ and the patch_start_idx indicating where patch tokens begin.
+ """
+ B, S, C_in, H, W = images.shape
+
+ if C_in != 3:
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
+
+ # Normalize images and reshape for patch embed
+ images = (images - self._resnet_mean) / self._resnet_std
+
+ # Reshape to [B*S, C, H, W] for patch embedding
+ images = images.view(B * S, C_in, H, W)
+ patch_tokens = self.patch_embed(images)
+
+ if isinstance(patch_tokens, dict):
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
+
+ _, P, C = patch_tokens.shape
+
+ # Expand camera and register tokens to match batch size and sequence length
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
+
+ # Concatenate special tokens with patch tokens
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
+
+ pos = None
+ if self.rope is not None:
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
+
+ if self.patch_start_idx > 0:
+ # Special tokens (camera and register tokens) carry no spatial position:
+ # shift the patch positions by +1 and reserve position 0 for the special tokens.
+ pos = pos + 1
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ # update P because we added special tokens
+ _, P, C = tokens.shape
+
+ frame_idx = 0
+ global_idx = 0
+ output_list = []
+
+ for _ in range(self.aa_block_num):
+ for attn_type in self.aa_order:
+ if attn_type == "frame":
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
+ tokens, B, S, P, C, frame_idx, pos=pos
+ )
+ elif attn_type == "global":
+ tokens, global_idx, global_intermediates = self._process_global_attention(
+ tokens, B, S, P, C, global_idx, pos=pos
+ )
+ else:
+ raise ValueError(f"Unknown attention type: {attn_type}")
+
+ for i in range(len(frame_intermediates)):
+ # concat frame and global intermediates, [B x S x P x 2C]
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
+ output_list.append(concat_inter)
+
+ del concat_inter
+ del frame_intermediates
+ del global_intermediates
+ return output_list, self.patch_start_idx
+
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
+ """
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
+ """
+ # If needed, reshape tokens or positions:
+ if tokens.shape != (B * S, P, C):
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
+
+ if pos is not None and pos.shape != (B * S, P, 2):
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.training:
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
+ else:
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
+ frame_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, frame_idx, intermediates
+
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
+ """
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
+ """
+ if tokens.shape != (B, S * P, C):
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
+
+ if pos is not None and pos.shape != (B, S * P, 2):
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ if self.training:
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
+ else:
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
+ global_idx += 1
+ intermediates.append(tokens.view(B, S, P, C))
+
+ return tokens, global_idx, intermediates
+
+
+def slice_expand_and_flatten(token_tensor, B, S):
+ """
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
+ 1) Uses the first position (index=0) for the first frame only
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
+ 3) Expands both to match batch size B
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
+ followed by (S-1) second-position tokens
+ 5) Flattens to (B*S, X, C) for processing
+
+ Returns:
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
+ """
+
+ # Slice out the "query" tokens => shape (1, 1, ...)
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
+ # Slice out the "other" tokens => shape (1, S-1, ...)
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
+ # Concatenate => shape (B, S, ...)
+ combined = torch.cat([query, others], dim=1)
+
+ # Finally flatten => shape (B*S, ...)
+ combined = combined.view(B * S, *combined.shape[2:])
+ return combined
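+
+
+# Usage sketch (illustrative only, not part of the original file):
+#
+#     >>> bank = torch.randn(1, 2, 4, 8)                 # (1, 2, X, C) token bank
+#     >>> slice_expand_and_flatten(bank, B=3, S=5).shape
+#     torch.Size([15, 4, 8])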
diff --git a/libs/vggt/models/vggt.py b/libs/vggt/models/vggt.py
new file mode 100644
index 0000000000000000000000000000000000000000..686e6f9d3f9e37769c195258f429a66d927375c0
--- /dev/null
+++ b/libs/vggt/models/vggt.py
@@ -0,0 +1,97 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from vggt.models.aggregator import Aggregator
+from vggt.heads.camera_head import CameraHead
+from vggt.heads.dpt_head import DPTHead
+from vggt.heads.track_head import TrackHead
+
+
+class VGGT(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024,
+ enable_camera=True, enable_point=True, enable_depth=True, enable_track=True):
+ super().__init__()
+
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
+
+ self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") if enable_point else None
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") if enable_depth else None
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) if enable_track else None
+
+ def forward(self, images: torch.Tensor, query_points: torch.Tensor = None):
+ """
+ Forward pass of the VGGT model.
+
+ Args:
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
+ Default: None
+
+ Returns:
+ dict: A dictionary containing the following predictions:
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
+ - images (torch.Tensor): Original input images, preserved for visualization
+
+ If query_points is provided, also includes:
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
+ """
+ # If without batch dimension, add it
+ if len(images.shape) == 4:
+ images = images.unsqueeze(0)
+
+ if query_points is not None and len(query_points.shape) == 2:
+ query_points = query_points.unsqueeze(0)
+
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
+
+ predictions = {}
+
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.camera_head is not None:
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
+ predictions["pose_enc_list"] = pose_enc_list
+
+ if self.depth_head is not None:
+ depth, depth_conf = self.depth_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["depth"] = depth
+ predictions["depth_conf"] = depth_conf
+
+ if self.point_head is not None:
+ pts3d, pts3d_conf = self.point_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["world_points"] = pts3d
+ predictions["world_points_conf"] = pts3d_conf
+
+ if self.track_head is not None and query_points is not None:
+ track_list, vis, conf = self.track_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
+ )
+ predictions["track"] = track_list[-1] # track of the last iteration
+ predictions["vis"] = vis
+ predictions["conf"] = conf
+
+ if not self.training:
+ predictions["images"] = images # store the images for visualization during inference
+
+ return predictions
+
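+
+# Usage sketch (illustrative only; "facebook/VGGT-1B" and the image paths are
+# placeholders for whichever checkpoint and frames you actually use):
+#
+#     >>> from vggt.utils.load_fn import load_and_preprocess_images
+#     >>> model = VGGT.from_pretrained("facebook/VGGT-1B").eval()
+#     >>> images = load_and_preprocess_images(["frame_000.png", "frame_001.png"])
+#     >>> with torch.no_grad():
+#     ...     preds = model(images)          # images are broadcast to [1, S, 3, H, W]
+#     >>> preds["pose_enc"].shape            # (1, S, 9)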
diff --git a/libs/vggt/utils/geometry.py b/libs/vggt/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..f555516dbc8a7dd8c7b15e6fbc928a5bfe8f740b
--- /dev/null
+++ b/libs/vggt/utils/geometry.py
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+
+
+from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
+
+
+def unproject_depth_map_to_point_map(
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
+) -> np.ndarray:
+ """
+ Unproject a batch of depth maps to 3D world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+ Returns:
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
+ """
+ if isinstance(depth_map, torch.Tensor):
+ depth_map = depth_map.cpu().numpy()
+ if isinstance(extrinsics_cam, torch.Tensor):
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
+ if isinstance(intrinsics_cam, torch.Tensor):
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
+
+ world_points_list = []
+ for frame_idx in range(depth_map.shape[0]):
+ cur_world_points, _, _ = depth_to_world_coords_points(
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
+ )
+ world_points_list.append(cur_world_points)
+ world_points_array = np.stack(world_points_list, axis=0)
+
+ return world_points_array
+
+
+def depth_to_world_coords_points(
+ depth_map: np.ndarray,
+ extrinsic: np.ndarray,
+ intrinsic: np.ndarray,
+ eps=1e-8,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+ # Apply the rotation and translation to the camera coordinates
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+ return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ np.ndarray: Camera coordinates of shape (H, W, 3)
+ """
+ H, W = depth_map.shape
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+ # Generate grid of pixel coordinates
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+ # Unproject to camera coordinates
+ x_cam = (u - cu) * depth_map / fu
+ y_cam = (v - cv) * depth_map / fv
+ z_cam = depth_map
+
+ # Stack to form camera coordinates
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ return cam_coords
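+
+# Quick sanity check (illustrative only, not part of the original file): with
+# unit depth and the principal point at the image center, z_cam is 1 everywhere
+# and x_cam/y_cam grow linearly away from the center.
+#
+#     >>> K = np.array([[100.0, 0.0, 2.0], [0.0, 100.0, 2.0], [0.0, 0.0, 1.0]])
+#     >>> pts = depth_to_cam_coords_points(np.ones((4, 4), dtype=np.float32), K)
+#     >>> pts.shape, float(pts[0, 0, 0])
+#     ((4, 4, 3), -0.02)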
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
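+
+# Quick sanity check (illustrative only, not part of the original file):
+# composing a pose with its closed-form inverse recovers the identity.
+#
+#     >>> se3 = np.eye(4)[None].copy()
+#     >>> se3[0, :3, 3] = [1.0, 2.0, 3.0]
+#     >>> bool(np.allclose(closed_form_inverse_se3(se3)[0] @ se3[0], np.eye(4)))
+#     True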
+
+
+# TODO: this code can be further cleaned up
+
+
+def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
+ """
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
+ Args:
+ world_points (torch.Tensor): 3D points of shape BxSxHxWx3.
+ cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4.
+ Returns:
+ torch.Tensor: Points in camera coordinates of shape BxSxHxWx3.
+ """
+ # TODO: merge this into project_world_points_to_cam
+
+ # device = world_points.device
+ # with torch.autocast(device_type=device.type, enabled=False):
+ ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1)
+ world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4)
+
+ # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4)
+ extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3)
+
+ # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1)
+ world_points_h_exp = world_points_h.unsqueeze(-1)
+
+ # Now perform the matrix multiplication
+ # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1)
+ camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1)
+
+ return camera_points
+
+
+
+def project_world_points_to_cam(
+ world_points,
+ cam_extrinsics,
+ cam_intrinsics=None,
+ distortion_params=None,
+ default=0,
+ only_points_cam=False,
+):
+ """
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
+ Args:
+ world_points (torch.Tensor): 3D points of shape Px3.
+ cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
+ cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
+ distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion.
+ Returns:
+ tuple: (image_points, cam_points)
+ - image_points (torch.Tensor): Projected 2D points of shape BxPx2, or None if only_points_cam=True.
+ - cam_points (torch.Tensor): Points in camera coordinates of shape Bx3xP.
+ """
+ device = world_points.device
+ # with torch.autocast(device_type=device.type, dtype=torch.double):
+ with torch.autocast(device_type=device.type, enabled=False):
+ N = world_points.shape[0] # Number of points
+ B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras
+ world_points_homogeneous = torch.cat(
+ [world_points, torch.ones_like(world_points[..., 0:1])], dim=1
+ ) # Nx4
+ # Reshape for batch processing
+ world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand(
+ B, -1, -1
+ ) # BxNx4
+
+ # Step 1: Apply extrinsic parameters
+ # Transform 3D points to camera coordinate system for all cameras
+ cam_points = torch.bmm(
+ cam_extrinsics, world_points_homogeneous.transpose(-1, -2)
+ )
+
+ if only_points_cam:
+ return None, cam_points
+
+ # Step 2: Apply intrinsic parameters and (optional) distortion
+ image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)
+
+ return image_points, cam_points
+
+
+
+def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
+ """
+ Applies intrinsic parameters and optional distortion to the given 3D points.
+
+ Args:
+ cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
+ cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
+ distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
+ default (float, optional): Default value to replace NaNs in the output.
+
+ Returns:
+ pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
+ """
+
+ # Perspective division: normalize by depth to obtain image-plane (normalized) coordinates
+ cam_points = cam_points / cam_points[:, 2:3, :]
+ ndc_xy = cam_points[:, :2, :]
+
+ # Apply distortion if distortion_params are provided
+ if distortion_params is not None:
+ x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1])
+ distorted_xy = torch.stack([x_distorted, y_distorted], dim=1)
+ else:
+ distorted_xy = ndc_xy
+
+ # Prepare cam_points for batch matrix multiplication
+ cam_coords_homo = torch.cat(
+ (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1
+ ) # Bx3xN
+ # Apply intrinsic parameters using batch matrix multiplication
+ pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN
+
+ # Extract x and y coordinates
+ pixel_coords = pixel_coords[:, :2, :] # Bx2xN
+
+ # Replace NaNs with default value
+ pixel_coords = torch.nan_to_num(pixel_coords, nan=default)
+
+ return pixel_coords.transpose(1, 2) # BxNx2
+
+
+
+
+def cam_from_img(pred_tracks, intrinsics, extra_params=None):
+ """
+ Normalize predicted tracks based on camera intrinsics.
+ Args:
+ intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3].
+ pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2].
+ extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
+ Returns:
+ torch.Tensor: Normalized tracks tensor.
+ """
+
+ # We deliberately avoid computing intrinsics_inv = torch.inverse(intrinsics);
+ # if the inverse were available, one could equivalently write
+ # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2))
+
+ principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
+ focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
+ tracks_normalized = (pred_tracks - principal_point) / focal_length
+
+ if extra_params is not None:
+ # Apply iterative undistortion
+ try:
+ tracks_normalized = iterative_undistortion(
+ extra_params, tracks_normalized
+ )
+ except Exception:
+ tracks_normalized = single_undistortion(
+ extra_params, tracks_normalized
+ )
+
+ return tracks_normalized
\ No newline at end of file
diff --git a/libs/vggt/utils/helper.py b/libs/vggt/utils/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b019189c85ff86645a4cf3756632aa8d4500649
--- /dev/null
+++ b/libs/vggt/utils/helper.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+
+def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray:
+ """
+ If mask has more than max_trues True values,
+ randomly keep only max_trues of them and set the rest to False.
+ """
+ # 1D positions of all True entries
+ true_indices = np.flatnonzero(mask) # shape = (N_true,)
+
+ # if already within budget, return as-is
+ if true_indices.size <= max_trues:
+ return mask
+
+ # randomly pick which True positions to keep
+ sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,)
+
+ # build new flat mask: True only at sampled positions
+ limited_flat_mask = np.zeros(mask.size, dtype=bool)
+ limited_flat_mask[sampled_indices] = True
+
+ # restore original shape
+ return limited_flat_mask.reshape(mask.shape)
+
+
+def create_pixel_coordinate_grid(num_frames, height, width):
+ """
+ Creates a grid of pixel coordinates and frame indices for all frames.
+ Returns:
+ numpy.ndarray: points_xyf, array of shape (num_frames, height, width, 3)
+ holding the x, y pixel coordinates and the frame index for every pixel.
+ """
+ # Create coordinate grids for a single frame
+ y_grid, x_grid = np.indices((height, width), dtype=np.float32)
+ x_grid = x_grid[np.newaxis, :, :]
+ y_grid = y_grid[np.newaxis, :, :]
+
+ # Broadcast to all frames
+ x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
+ y_coords = np.broadcast_to(y_grid, (num_frames, height, width))
+
+ # Create frame indices and broadcast
+ f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
+ f_coords = np.broadcast_to(f_idx, (num_frames, height, width))
+
+ # Stack coordinates and frame indices
+ points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
+
+ return points_xyf
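+
+
+# Usage sketch (illustrative only, not part of the original file):
+#
+#     >>> create_pixel_coordinate_grid(num_frames=2, height=3, width=4).shape
+#     (2, 3, 4, 3)
+#     >>> mask = np.ones((3, 4), dtype=bool)
+#     >>> int(randomly_limit_trues(mask, 5).sum())
+#     5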
diff --git a/libs/vggt/utils/load_fn.py b/libs/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d223aabdc43ac644c1b8ca376e8fec59decd084
--- /dev/null
+++ b/libs/vggt/utils/load_fn.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from PIL import Image
+from torchvision import transforms as TF
+import numpy as np
+
+
+def load_and_preprocess_images_square(image_path_list, target_size=1024):
+ """
+ Load and preprocess images by center padding to square and resizing to target size.
+ Also returns the position information of original pixels after transformation.
+
+ Args:
+ image_path_list (list): List of paths to image files
+ target_size (int, optional): Target size for both width and height. Defaults to 1024.
+
+ Returns:
+ tuple: (
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
+ torch.Tensor: Array of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
+ )
+
+ Raises:
+ ValueError: If the input list is empty
+ """
+ # Check for empty list
+ if len(image_path_list) == 0:
+ raise ValueError("At least 1 image is required")
+
+ images = []
+ original_coords = [] # Renamed from position_info to be more descriptive
+ to_tensor = TF.ToTensor()
+
+ for image_path in image_path_list:
+ # Open image
+ img = Image.open(image_path)
+
+ # If there's an alpha channel, blend onto white background
+ if img.mode == "RGBA":
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+ img = Image.alpha_composite(background, img)
+
+ # Convert to RGB
+ img = img.convert("RGB")
+
+ # Get original dimensions
+ width, height = img.size
+
+ # Make the image square by padding the shorter dimension
+ max_dim = max(width, height)
+
+ # Calculate padding
+ left = (max_dim - width) // 2
+ top = (max_dim - height) // 2
+
+ # Calculate scale factor for resizing
+ scale = target_size / max_dim
+
+ # Calculate final coordinates of original image in target space
+ x1 = left * scale
+ y1 = top * scale
+ x2 = (left + width) * scale
+ y2 = (top + height) * scale
+
+ # Store original image coordinates and scale
+ original_coords.append(np.array([x1, y1, x2, y2, width, height]))
+
+ # Create a new black square image and paste original
+ square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
+ square_img.paste(img, (left, top))
+
+ # Resize to target size
+ square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)
+
+ # Convert to tensor
+ img_tensor = to_tensor(square_img)
+ images.append(img_tensor)
+
+ # Stack all images
+ images = torch.stack(images)
+ original_coords = torch.from_numpy(np.array(original_coords)).float()
+
+ # Add additional dimension if single image to ensure correct shape
+ if len(image_path_list) == 1:
+ if images.dim() == 3:
+ images = images.unsqueeze(0)
+ original_coords = original_coords.unsqueeze(0)
+
+ return images, original_coords
+
+
+def load_and_preprocess_images(image_path_list, mode="crop"):
+ """
+ A quick start function to load and preprocess images for model input.
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
+
+ Args:
+ image_path_list (list): List of paths to image files
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
+ - "crop" (default): Sets width to 518px and center crops height if needed.
+ - "pad": Preserves all pixels by making the largest dimension 518px
+ and padding the smaller dimension to reach a square shape.
+
+ Returns:
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
+
+ Raises:
+ ValueError: If the input list is empty or if mode is invalid
+
+ Notes:
+ - Images with different dimensions will be padded with white (value=1.0)
+ - A warning is printed when images have different shapes
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
+ and height is center-cropped if larger than 518px
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
+ and the smaller dimension is padded to reach a square shape (518x518)
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
+ """
+ # Check for empty list
+ if len(image_path_list) == 0:
+ raise ValueError("At least 1 image is required")
+
+ # Validate mode
+ if mode not in ["crop", "pad"]:
+ raise ValueError("Mode must be either 'crop' or 'pad'")
+
+ images = []
+ shapes = set()
+ to_tensor = TF.ToTensor()
+ target_size = 518
+
+ # First process all images and collect their shapes
+ for image_path in image_path_list:
+ # Open image
+ img = Image.open(image_path)
+
+ # If there's an alpha channel, blend onto white background:
+ if img.mode == "RGBA":
+ # Create white background
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+ # Alpha composite onto the white background
+ img = Image.alpha_composite(background, img)
+
+ # Now convert to "RGB" (this step assigns white for transparent areas)
+ img = img.convert("RGB")
+
+ width, height = img.size
+
+ if mode == "pad":
+ # Make the largest dimension 518px while maintaining aspect ratio
+ if width >= height:
+ new_width = target_size
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
+ else:
+ new_height = target_size
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
+ else: # mode == "crop"
+ # Original behavior: set width to 518px
+ new_width = target_size
+ # Calculate height maintaining aspect ratio, divisible by 14
+ new_height = round(height * (new_width / width) / 14) * 14
+
+ # Resize with new dimensions (width, height)
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+ img = to_tensor(img) # Convert to tensor (0, 1)
+
+ # Center crop height if it's larger than 518 (only in crop mode)
+ if mode == "crop" and new_height > target_size:
+ start_y = (new_height - target_size) // 2
+ img = img[:, start_y : start_y + target_size, :]
+
+ # For pad mode, pad to make a square of target_size x target_size
+ if mode == "pad":
+ h_padding = target_size - img.shape[1]
+ w_padding = target_size - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ # Pad with white (value=1.0)
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+
+ shapes.add((img.shape[1], img.shape[2]))
+ images.append(img)
+
+ # Check if we have different shapes
+ # In theory our model can also work well with different shapes
+ if len(shapes) > 1:
+ print(f"Warning: Found images with different shapes: {shapes}")
+ # Find maximum dimensions
+ max_height = max(shape[0] for shape in shapes)
+ max_width = max(shape[1] for shape in shapes)
+
+ # Pad images if necessary
+ padded_images = []
+ for img in images:
+ h_padding = max_height - img.shape[1]
+ w_padding = max_width - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+ padded_images.append(img)
+ images = padded_images
+
+ images = torch.stack(images) # concatenate images
+
+ # Ensure correct shape when single image
+ if len(image_path_list) == 1:
+ # Verify shape is (1, C, H, W)
+ if images.dim() == 3:
+ images = images.unsqueeze(0)
+
+ return images
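+
+
+# Usage sketch (illustrative only; the file names below are placeholders):
+#
+#     >>> images = load_and_preprocess_images(["frame_000.png", "frame_001.png"], mode="crop")
+#     >>> images.shape                                   # (2, 3, H, 518) in [0, 1], H a multiple of 14
+#     >>> imgs_sq, coords = load_and_preprocess_images_square(["frame_000.png"], target_size=1024)
+#     >>> imgs_sq.shape, coords.shape                    # (1, 3, 1024, 1024), (1, 6)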
diff --git a/libs/vggt/utils/pose_enc.py b/libs/vggt/utils/pose_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3b964330af0e62f4d36d332317ae00cb99b3a9
--- /dev/null
+++ b/libs/vggt/utils/pose_enc.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from .rotation import quat_to_mat, mat_to_quat
+
+
+def extri_intri_to_pose_encoding(
+ extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512)
+):
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+ This function transforms camera parameters into a unified pose encoding format,
+ which can be used for various downstream tasks like pose prediction or representation.
+
+ Args:
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+ where B is batch size and S is sequence length.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+ Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for computing field of view values. For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+ Returns:
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ """
+
+ # extrinsics: BxSx3x4
+ # intrinsics: BxSx3x3
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
+ T = extrinsics[:, :, :3, 3] # BxSx3
+
+ quat = mat_to_quat(R)
+ # Note the order of h and w here
+ H, W = image_size_hw
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+ else:
+ raise NotImplementedError
+
+ return pose_encoding
+
+
+def pose_encoding_to_extri_intri(
+ pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512)
+):
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
+
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
+ reconstructing the full camera parameters from the compact encoding.
+
+ Args:
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
+ where B is batch size and S is sequence length.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for reconstructing intrinsics from field of view values.
+ For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding used. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
+ If False, only extrinsics are returned and intrinsics will be None.
+
+ Returns:
+ tuple: (extrinsics, intrinsics)
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
+ a 3x1 translation vector.
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
+ or None if build_intrinsics is False. Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
+ assumed to be at the center of the image (W/2, H/2).
+ """
+
+ intrinsics = None
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ T = pose_encoding[..., :3]
+ quat = pose_encoding[..., 3:7]
+ fov_h = pose_encoding[..., 7]
+ fov_w = pose_encoding[..., 8]
+
+ R = quat_to_mat(quat)
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
+
+ if build_intrinsics:
+ H, W = image_size_hw
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
+ intrinsics[..., 0, 0] = fx
+ intrinsics[..., 1, 1] = fy
+ intrinsics[..., 0, 2] = W / 2
+ intrinsics[..., 1, 2] = H / 2
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
+ else:
+ raise NotImplementedError
+
+ return extrinsics, intrinsics
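+
+
+# A small self-check sketch (not part of the original module): encode simple
+# camera parameters and decode them back. Shapes follow the "absT_quaR_FoV"
+# convention documented above; the test values are arbitrary.
+if __name__ == "__main__":
+    B, S, H, W = 1, 2, 256, 512
+    extrinsics = torch.zeros(B, S, 3, 4)
+    extrinsics[..., :3, :3] = torch.eye(3)                   # identity rotation
+    extrinsics[..., :3, 3] = torch.tensor([0.1, 0.2, 0.3])   # small translation
+    intrinsics = torch.tensor(
+        [[300.0, 0.0, W / 2], [0.0, 300.0, H / 2], [0.0, 0.0, 1.0]]
+    ).expand(B, S, 3, 3)
+    enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
+    extr_rec, intr_rec = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
+    print("extrinsics max error:", (extr_rec - extrinsics).abs().max().item())
+    print("intrinsics max error:", (intr_rec - intrinsics).abs().max().item())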
diff --git a/libs/vggt/utils/rotation.py b/libs/vggt/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f972afd8414c82fa1e9ed231725fd3f9f6ebde77
--- /dev/null
+++ b/libs/vggt/utils/rotation.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Quaternion Order: XYZW or say ijkr, scalar-last
+
+ Convert rotations given as quaternions to rotation matrices.
+ Args:
+ quaternions: quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ i, j, k, r = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
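+
+
+# A minimal sanity-check sketch (not part of the original module): random unit
+# quaternions should survive a quat -> matrix -> quat round trip once the sign
+# is standardized. The seed and batch size are arbitrary.
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    q = standardize_quaternion(F.normalize(torch.randn(16, 4), dim=-1))
+    q_rec = mat_to_quat(quat_to_mat(q))
+    print("max round-trip error:", (q - q_rec).abs().max().item())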
diff --git a/libs/vggt/utils/visual_track.py b/libs/vggt/utils/visual_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154
--- /dev/null
+++ b/libs/vggt/utils/visual_track.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import cv2
+import torch
+import numpy as np
+import os
+
+
+def color_from_xy(x, y, W, H, cmap_name="hsv"):
+ """
+ Map (x, y) -> color in (R, G, B).
+ 1) Normalize x,y to [0,1].
+ 2) Combine them into a single scalar c in [0,1].
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
+
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
+ """
+ import matplotlib.cm
+ import matplotlib.colors
+
+ x_norm = x / max(W - 1, 1)
+ y_norm = y / max(H - 1, 1)
+ # Simple combination:
+ c = (x_norm + y_norm) / 2.0
+
+ cmap = matplotlib.cm.get_cmap(cmap_name)
+ # cmap(c) -> (r,g,b,a) in [0,1]
+ rgba = cmap(c)
+ r, g, b = rgba[0], rgba[1], rgba[2]
+ return (r, g, b) # in [0,1], RGB order
+
+
+def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
+ """
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
+ in [0,255]. The color is determined by the (x,y) position in the first
+ visible frame for each track.
+
+ Args:
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
+ image_width, image_height: used for normalizing (x, y).
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
+
+ Returns:
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
+ """
+ S, N, _ = tracks_b.shape
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
+
+ if vis_mask_b is None:
+ # treat all as visible
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
+
+ for i in range(N):
+ # Find first visible frame for track i
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
+ if len(visible_frames) == 0:
+ # track is never visible; just assign black or something
+ track_colors[i] = (0, 0, 0)
+ continue
+
+ first_s = int(visible_frames[0].item())
+ # use that frame's (x,y)
+ x, y = tracks_b[first_s, i].tolist()
+
+ # map (x,y) -> (R,G,B) in [0,1]
+ r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
+ # scale to [0,255]
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
+ track_colors[i] = (r, g, b)
+
+ return track_colors
+
+
+def visualize_tracks_on_images(
+ images,
+ tracks,
+ track_vis_mask=None,
+ out_dir="track_visuals_concat_by_xy",
+ image_format="CHW", # "CHW" or "HWC"
+ normalize_mode="[0,1]",
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
+ frames_per_row=4, # New parameter for grid layout
+ save_grid=True, # Flag to control whether to save the grid image
+):
+ """
+ Visualizes frames in a grid layout with specified frames per row.
+ Each track's color is determined by its (x,y) position
+ in the first visible frame (or frame 0 if always visible).
+ Finally convert the BGR result to RGB before saving.
+ Also saves each individual frame as a separate PNG file.
+
+ Args:
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
+ track_vis_mask: torch.Tensor (S, N) or None.
+ out_dir: folder to save visualizations.
+ image_format: "CHW" or "HWC".
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
+ cmap_name: a matplotlib colormap name for color_from_xy.
+ frames_per_row: number of frames to display in each row of the grid.
+ save_grid: whether to save all frames in one grid image.
+
+ Returns:
+ None (saves images in out_dir).
+ """
+
+ if len(tracks.shape) == 4:
+ tracks = tracks.squeeze(0)
+ images = images.squeeze(0)
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.squeeze(0)
+
+ import matplotlib
+
+ matplotlib.use("Agg") # for non-interactive (optional)
+
+ os.makedirs(out_dir, exist_ok=True)
+
+ S = images.shape[0]
+ _, N, _ = tracks.shape # (S, N, 2)
+
+ # Move to CPU
+ images = images.cpu().clone()
+ tracks = tracks.cpu().clone()
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.cpu().clone()
+
+ # Infer H, W from images shape
+ if image_format == "CHW":
+ # e.g. images[s].shape = (3, H, W)
+ H, W = images.shape[2], images.shape[3]
+ else:
+ # e.g. images[s].shape = (H, W, 3)
+ H, W = images.shape[1], images.shape[2]
+
+ # Pre-compute the color for each track i based on first visible position
+ track_colors_rgb = get_track_colors_by_position(
+ tracks, # shape (S, N, 2)
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
+ image_width=W,
+ image_height=H,
+ cmap_name=cmap_name,
+ )
+
+ # We'll accumulate each frame's drawn image in a list
+ frame_images = []
+
+ for s in range(S):
+ # shape => either (3, H, W) or (H, W, 3)
+ img = images[s]
+
+ # Convert to (H, W, 3)
+ if image_format == "CHW":
+ img = img.permute(1, 2, 0) # (H, W, 3)
+ # else "HWC", do nothing
+
+ img = img.numpy().astype(np.float32)
+
+ # Scale to [0,255] if needed
+ if normalize_mode == "[0,1]":
+ img = np.clip(img, 0, 1) * 255.0
+ elif normalize_mode == "[-1,1]":
+ img = (img + 1.0) * 0.5 * 255.0
+ img = np.clip(img, 0, 255.0)
+ # else no normalization
+
+ # Convert to uint8
+ img = img.astype(np.uint8)
+
+ # For drawing in OpenCV, convert to BGR
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+ # Draw each visible track
+ cur_tracks = tracks[s] # shape (N, 2)
+ if track_vis_mask is not None:
+ valid_indices = torch.where(track_vis_mask[s])[0]
+ else:
+ valid_indices = range(N)
+
+ cur_tracks_np = cur_tracks.numpy()
+ for i in valid_indices:
+ x, y = cur_tracks_np[i]
+ pt = (int(round(x)), int(round(y)))
+
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
+ R, G, B = track_colors_rgb[i]
+ color_bgr = (int(B), int(G), int(R))
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+ # Convert back to RGB for consistent final saving:
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+ # Save individual frame
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+ # Convert to BGR for OpenCV imwrite
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(frame_path, frame_bgr)
+
+ frame_images.append(img_rgb)
+
+ # Only create and save the grid image if save_grid is True
+ if save_grid:
+ # Calculate grid dimensions
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
+
+ # Create a grid of images
+ grid_img = None
+ for row in range(num_rows):
+ start_idx = row * frames_per_row
+ end_idx = min(start_idx + frames_per_row, S)
+
+ # Concatenate this row horizontally
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+ # If this row has fewer than frames_per_row images, pad with black
+ if end_idx - start_idx < frames_per_row:
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+ row_img = np.concatenate([row_img, padding], axis=1)
+
+ # Add this row to the grid
+ if grid_img is None:
+ grid_img = row_img
+ else:
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
+
+ out_path = os.path.join(out_dir, "tracks_grid.png")
+ # Convert back to BGR for OpenCV imwrite
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(out_path, grid_img_bgr)
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
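+
+
+# A minimal usage sketch with synthetic data (output folder and sizes are
+# arbitrary): renders a few random tracks so the color-by-position mapping and
+# the grid layout can be inspected without real model output.
+if __name__ == "__main__":
+    S, N, H, W = 4, 32, 64, 96
+    dummy_images = torch.rand(S, 3, H, W)  # CHW frames in [0, 1]
+    dummy_tracks = torch.rand(S, N, 2) * torch.tensor([W - 1.0, H - 1.0])
+    visualize_tracks_on_images(dummy_images, dummy_tracks, out_dir="track_visuals_demo")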
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..993d44bc80a7104a1973ad6fef3d0201243c6e7c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,51 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.5.1
+torchvision==0.20.1
+xformers==0.0.29.post1
+-f https://data.pyg.org/whl/torch-2.5.1+cu121.html
+torch-cluster
+diffusers==0.33.0
+accelerate==1.1.1
+moviepy==1.0.3
+hydra-core==1.3.2
+taichi==1.7.3
+trimesh==4.4.0
+numpy==1.26.1
+opencv-python==4.11.0.86
+opencv-python-headless==4.11.0.86
+rembg==2.0.63
+huggingface_hub==0.31.1
+gradio==4.44.1
+gradio_image_prompter==0.1.0
+uvicorn==0.34.2
+fastapi==0.112.0
+pydantic==2.8.2
+plotly==5.24.1
+SentencePiece==0.2.0
+
+warp-lang
+matplotlib
+pyyaml
+h5py
+einops
+timm
+scikit-image
+open3d
+PyMCubes
+plyfile
+transformers==4.45.1
+cython
+rich
+munch
+omegaconf
+flow_vis
+kiui
+segment_anything
+rembg
+onnxruntime
+tyro
+roma
+safetensors
+# git+https://github.com/ashawkey/diff-gaussian-rasterization.git
+git+https://github.com/asomoza/image_gen_aux.git
+./wheels/diff-gaussian-rasterization-0.1.1-cp310-cp310-linux_x86_64.whl
\ No newline at end of file
diff --git a/src/configs/acc/1gpu.yaml b/src/configs/acc/1gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9dcf83fb8b0bac2aa17f9c0c0459a08d843e3e29
--- /dev/null
+++ b/src/configs/acc/1gpu.yaml
@@ -0,0 +1,14 @@
+compute_environment: LOCAL_MACHINE
+distributed_type: 'NO'
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'fp16'
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
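+# Typical usage (a sketch; the actual training entry point is not shown here):
+#   accelerate launch --config_file src/configs/acc/1gpu.yaml <training script and its arguments>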
diff --git a/src/configs/acc/2gpu.yaml b/src/configs/acc/2gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cff0f583ff50d54a7c6fe0b0d9909a8bd93c5c42
--- /dev/null
+++ b/src/configs/acc/2gpu.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+main_process_port: 19001
\ No newline at end of file
diff --git a/src/configs/acc/4gpu.yaml b/src/configs/acc/4gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..080c1469a94611d7993b8fe25c1552f5e7e4b077
--- /dev/null
+++ b/src/configs/acc/4gpu.yaml
@@ -0,0 +1,15 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/src/configs/acc/8gpu.yaml b/src/configs/acc/8gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..10ba6f1e6b9bf4b238e0b7245a78fd4ea22fc7a9
--- /dev/null
+++ b/src/configs/acc/8gpu.yaml
@@ -0,0 +1,15 @@
+compute_environment: LOCAL_MACHINE
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/configs/config_dit_base.yaml b/src/configs/config_dit_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..876d9a207c3f31ae0230379284a5ab9bdd62b71a
--- /dev/null
+++ b/src/configs/config_dit_base.yaml
@@ -0,0 +1,76 @@
+image_size: 128
+output_dir: ./outputs/dit_${model_config.n_layers}layers_${pc_size}p_${train_dataset.category}_${train_dataset.n_training_frames}frames_pointembed_latent${model_config.latent_dim}_deform${lambda_deform}_8gpus_base
+seed: 0
+train_batch_size: 2
+eval_batch_size: 4
+num_train_epochs: 50
+max_train_steps: 60000
+gradient_accumulation_steps: 25
+gradient_checkpointing: true
+learning_rate: 1e-4
+scale_lr: false
+lr_scheduler: "constant_with_warmup"
+lr_warmup_steps: 100
+use_8bit_adam: false
+allow_tf32: true
+dataloader_num_workers: 48
+adam_beta1: 0.9
+adam_beta2: 0.999
+adam_weight_decay: 1.e-2
+adam_epsilon: 1.e-08
+max_grad_norm: 1.0
+prediction_type: null
+vis_dir: vis
+logging_dir: logs
+mixed_precision: 'bf16'
+report_to: 'tensorboard'
+local_rank: -1
+checkpointing_steps: 2500
+checkpoints_total_limit: 10
+resume_from_checkpoint: latest
+enable_xformers_memory_efficient_attention: true
+validation_steps: 500
+validation_train_steps: 2000
+validation_sanity_check: true
+tracker_project_name: 'diffusion'
+push_to_hub: false
+set_grads_to_none: true
+lambda_vel: 1.0
+lambda_mask: 0.0
+lambda_momentum: 0.0
+lambda_deform: 0.001
+pc_size: 2048
+condition_drop_rate: 0.0
+model_type: 'dit_st'
+model_config:
+ n_layers: 8
+ latent_dim: 256
+ frame_cond: true
+ point_embed: true
+ mask_cond: false
+ pred_offset: true
+ num_neighbors: -1
+ floor_cond: false
+ max_num_forces: 1
+ force_as_token: false
+ force_as_latent: false
+ coeff_cond: false
+ class_token: false
+ transformer_block: SpatialTemporalTransformerBlock
+train_dataset:
+ category: hf-objaverse-v1
+ dataset_path: DATASET_FOLDER
+ dataset_list: DATASET_ITEM_LIST
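+  # DATASET_FOLDER and DATASET_ITEM_LIST are placeholders: set them to the local
+  # folder of .h5 animations and to a JSON file listing the .h5 filenames
+  # (see how TrajDataset reads them in src/dataset/traj_dataset.py).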
+  has_gravity: true
+ max_num_forces: ${model_config.max_num_forces}
+ norm_fac: 5
+ stage: 'deform'
+ mode: 'diff'
+ pc_size: ${pc_size}
+ repeat: 1
+ seed: 0
+ n_sample_pro_model: 300
+ n_frames_interval: 2
+ n_training_frames: 24
+ batch_size: 20
+ overfit: false
diff --git a/src/configs/config_dit_large.yaml b/src/configs/config_dit_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f82e8fa192bbcaa2aca32170e496b79b3ecbed62
--- /dev/null
+++ b/src/configs/config_dit_large.yaml
@@ -0,0 +1,78 @@
+image_size: 128
+output_dir: ./outputs/dit_${model_config.n_layers}layers_${pc_size}p_${train_dataset.category}_${train_dataset.n_training_frames}frames_pointembed_latent${model_config.latent_dim}_deform${lambda_deform}_8gpus_large
+seed: 0
+train_batch_size: 2
+eval_batch_size: 4
+num_train_epochs: 50
+max_train_steps: 100000
+gradient_accumulation_steps: 25
+gradient_checkpointing: true
+learning_rate: 1e-4
+scale_lr: false
+lr_scheduler: "constant_with_warmup"
+lr_warmup_steps: 100
+use_8bit_adam: false
+allow_tf32: true
+dataloader_num_workers: 48
+adam_beta1: 0.9
+adam_beta2: 0.999
+adam_weight_decay: 1.e-2
+adam_epsilon: 1.e-08
+max_grad_norm: 1.0
+prediction_type: null
+vis_dir: vis
+logging_dir: logs
+mixed_precision: 'bf16'
+report_to: 'tensorboard'
+local_rank: -1
+checkpointing_steps: 2500
+checkpoints_total_limit: 10
+resume_from_checkpoint: latest
+enable_xformers_memory_efficient_attention: true
+validation_steps: 500
+validation_train_steps: 2000
+validation_sanity_check: true
+tracker_project_name: 'diffusion'
+push_to_hub: false
+set_grads_to_none: true
+lambda_vel: 1.0
+lambda_mask: 0.0
+lambda_momentum: 0.0
+lambda_deform: 0.001
+pc_size: 2048
+condition_drop_rate: 0.0
+model_type: 'dit_st'
+model_config:
+ n_layers: 8
+ latent_dim: 256
+ frame_cond: true
+ point_embed: true
+ mask_cond: false
+ pred_offset: true
+ num_neighbors: -1
+ floor_cond: true
+ max_num_forces: 1
+ force_as_token: true
+ force_as_latent: false
+ gravity_emb: true
+ coeff_cond: false
+ num_mat: 4
+ class_token: true
+ transformer_block: SpatialTemporalTransformerBlock
+train_dataset:
+ category: hf-objaverse-v1
+ dataset_path: DATASET_FOLDER
+ dataset_list: DATASET_ITEM_LIST
+ has_gravity: true
+ max_num_forces: ${model_config.max_num_forces}
+ norm_fac: 5
+ stage: 'deform'
+ mode: 'diff'
+ pc_size: ${pc_size}
+ repeat: 1
+ seed: 0
+ n_sample_pro_model: 300
+ n_frames_interval: 2
+ n_training_frames: 24
+ batch_size: 20
+ overfit: false
diff --git a/src/configs/eval_base.yaml b/src/configs/eval_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cee0b3c160855479037da43ca088d0e90096d3fa
--- /dev/null
+++ b/src/configs/eval_base.yaml
@@ -0,0 +1,39 @@
+pc_size: 2048
+eval_batch_size: 1
+dataloader_num_workers: 4
+seed: 0
+pred_offset: true
+model_type: 'dit_st'
+num_inference_steps: 4
+model_config:
+ n_layers: 8
+ latent_dim: 256
+ frame_cond: true
+ point_embed: true
+ mask_cond: false
+ pred_offset: true
+ num_neighbors: -1
+ floor_cond: false
+ max_num_forces: 1
+ force_as_token: false
+ force_as_latent: false
+ coeff_cond: false
+ class_token: false
+ transformer_block: SpatialTemporalTransformerBlock
+train_dataset:
+ category: hf-objaverse-v1
+ dataset_list: DATASET_ITEM_LIST_TEST
+ dataset_path: DATASET_FOLDER
+  has_gravity: false
+
+ norm_fac: 5
+ stage: 'deform'
+ mode: 'diff'
+ pc_size: 2048
+ repeat: 1
+ seed: 0
+ n_sample_pro_model: 300
+ n_frames_interval: 2
+ n_training_frames: 24
+ batch_size: 20
+ overfit: false
diff --git a/src/configs/eval_large.yaml b/src/configs/eval_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c6a671095b0faa4dbcba052b037a83df4ff61644
--- /dev/null
+++ b/src/configs/eval_large.yaml
@@ -0,0 +1,42 @@
+pc_size: 2048
+eval_batch_size: 1
+dataloader_num_workers: 4
+seed: 0
+pred_offset: true
+model_type: 'dit_st'
+num_inference_steps: 25
+model_config:
+ n_layers: 12
+ latent_dim: 512
+ frame_cond: true
+ point_embed: true
+ mask_cond: false
+ pred_offset: true
+ num_neighbors: -1
+ floor_cond: true
+ max_num_forces: 1
+ force_as_token: true
+ force_as_latent: false
+ gravity_emb: true
+ coeff_cond: false
+ num_mat: 4
+ class_token: true
+ transformer_block: SpatialTemporalTransformerBlock
+train_dataset:
+ category: hf-objaverse-v1
+ dataset_path: DATASET_FOLDER
+ dataset_list: DATASET_ITEM_LIST_TEST
+ has_gravity: true
+ max_num_forces: ${model_config.max_num_forces}
+
+ norm_fac: 5
+ stage: 'deform'
+ mode: 'diff'
+ pc_size: 2048
+ repeat: 1
+ seed: 0
+ n_sample_pro_model: 300
+ n_frames_interval: 2
+ n_training_frames: 24
+ batch_size: 20
+ overfit: false
diff --git a/src/dataset/traj_dataset.py b/src/dataset/traj_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8311479b18ec87c20dd8e502b71a77611b994b07
--- /dev/null
+++ b/src/dataset/traj_dataset.py
@@ -0,0 +1,286 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import os
+import h5py
+from torch_cluster import fps
+import json
+import random
+from rich import print  # the status prints below use rich markup tags like [bold red]
+
+class TrajDataset(Dataset):
+ def __init__(self, split, cfg):
+
+ self.cfg = cfg
+ self.dataset_path = cfg.dataset_path
+ self.split = split
+ self.stage = cfg.stage # 'shape' or 'deform'
+ self.mode = cfg.mode # 'ae' or 'diff'
+ self.repeat = cfg.repeat
+ self.seed = cfg.seed
+ self.pc_size = cfg.pc_size
+ self.n_sample_pro_model = cfg.n_sample_pro_model
+ self.n_frames_interval = cfg.n_frames_interval
+ self.n_training_frames = cfg.n_training_frames
+ self.batch_size = cfg.batch_size
+ self.has_gravity = cfg.get('has_gravity', False)
+ self.max_num_forces = cfg.get('max_num_forces', 1)
+
+ # if os.path.exists(os.path.join(self.dataset_path, cfg.dataset_list)):
+ if os.path.exists(cfg.dataset_list):
+ print(f'Loading {cfg.dataset_list}')
+ with open(cfg.dataset_list, 'r') as f:
+ self.split_lst = json.load(f)
+ else:
+ self.split_lst = [f for f in sorted(os.listdir(self.dataset_path)) if f.endswith('h5')]
+ random.seed(0)
+ random.shuffle(self.split_lst)
+ print('Number of data:', len(self.split_lst))
+
+ if cfg.overfit:
+ self.split_lst = self.split_lst[:1]
+ elif cfg.dataset_path.endswith('_test') or cfg.dataset_list.endswith('test.json') or cfg.dataset_list.endswith('test_list.json'):
+ self.split_lst = self.split_lst[:100]
+ print('Test split:', self.split_lst)
+ else:
+ if split == 'train':
+ self.split_lst = self.split_lst[:-4]
+ else:
+ self.split_lst = self.split_lst[-8:]
+ print('Test split:', self.split_lst)
+ self.split_lst_save = self.split_lst.copy()
+ self.split_lst_pcl_len = [49] * len(self.split_lst_save)
+ # if not os.path.exists(os.path.join(self.dataset_path, f'info_deform_ae_{split}.json')):
+ self.prepare_data_lst()
+ # with open(os.path.join(self.dataset_path, f'info_deform_ae_{split}.json'), "w") as f:
+ # json.dump(self.models, f)
+ # print(f'Saved info_deform_ae_{split}.json')
+ # else:
+ # self.models = json.load(open(os.path.join(self.dataset_path, f'info_deform_ae_{split}.json'), 'r'))
+ # print(f'Loaded info_deform_ae_{split}.json')
+
+ print("Current stage: [bold red]{}[/bold red]".format(self.stage))
+ print("Current mode: [bold red]{}[/bold red]".format(self.mode))
+ print("Current split: [bold red]{}[/bold red]".format(self.split))
+ print("Dataset is repeated [bold cyan]{}[/bold cyan] times".format(self.repeat))
+ print("Length of split: {}".format(len(self.split_lst) if self.stage == 'shape' else len(self.models)))
+
+ def prepare_data_lst(self):
+ self.models = []
+ if self.stage == 'deform':
+ if self.mode == 'ae':
+ if self.split == 'train':
+ models_out, indices_out = self.random_sample_indexes(self.split_lst_save * self.repeat, self.split_lst_pcl_len * self.repeat)
+ self.models += [{"model": m, "indices": indices_out[i]} for i, m in enumerate(models_out)]
+ else: # Evaluate
+ for m in self.split_lst_save:
+ for i in range(1, self.batch_size + 1):
+ self.models += [{"model": m, "indices": [i-1, i]}]
+ elif self.mode == 'diff':
+ # models_out, indices_out = self.subdivide_into_sequences(self.split_lst_save * self.repeat, self.split_lst_pcl_len * self.repeat)
+ # self.models += [{"model": m, "start_idx": indices_out[i]} for i, m in enumerate(models_out)]
+ self.models += [{"model": m, "start_idx": 0} for i, m in enumerate(self.split_lst_save)]
+ else:
+ raise NotImplementedError("mode not implemented")
+
+ def __getitem__(self, index):
+ if self.stage == 'deform':
+ if self.mode == 'ae':
+ return self.get_deform_ae(index)
+ elif self.mode == 'diff':
+ return self.get_deform_diff(index)
+
+ def __len__(self):
+ if self.stage == 'deform':
+ if self.mode == 'ae':
+ if self.split == 'train':
+ return sum(self.split_lst_pcl_len) * self.repeat
+ else:
+ return len(self.split_lst_save) * self.batch_size # number of sequences
+ elif self.mode == 'diff':
+ return len(self.models)
+ else:
+ raise NotImplementedError("mode not implemented")
+
+ def random_sample_indexes(self, models, models_len):
+ n_sample_pro_model = self.n_sample_pro_model
+ interval_between_frames = self.interval_between_frames
+ n_selected_frames = self.n_selected_frames
+
+ # Initialize output lists
+ models_out = []
+ indexes_out = []
+
+ # Loop over each model
+ for idx, model in enumerate(models):
+ # For each sample per model
+ for n in range(n_sample_pro_model):
+ # Initialize indices list for current sample
+ indexes = []
+
+ # Select n_selected_frames number of indices
+ for i in range(n_selected_frames):
+ # If first index, randomly select from range
+ if i == 0:
+ # indexes.append(np.random.randint(0, models_len[idx] - interval_between_frames))
+ indexes.append(np.random.randint(0, models_len[idx]))
+ else:
+ # For subsequent indices, select within interval_between_frames from the previous index
+ indexes.append( min(indexes[-1] + np.random.randint(0, interval_between_frames), models_len[idx]-1) )
+
+ # Append the selected indices and corresponding model to output lists
+ indexes_out.append(sorted(indexes))
+ models_out.append(model)
+
+ return models_out, indexes_out
+
+ def get_deform_ae(self, index):
+ model = self.models[index]
+ model_name = model["model"]
+ model_indices = model["indices"]
+
+ model_info = {}
+ model_info["model"] = model_name
+ model_info["indices"] = model_indices
+
+ model_metas = h5py.File(os.path.join(self.dataset_path, f'{model_name}'), 'r')
+ model_pcls = torch.from_numpy(np.array(model_metas['x']))
+
+ ind = np.random.default_rng(seed=self.seed).choice(model_pcls[0].shape[0], self.pc_size, replace=False)
+ points_src = model_pcls[model_indices[0]][ind]
+ points_tgt = model_pcls[model_indices[1]][ind]
+
+ model_data = {}
+ model_data['points_src'] = points_src.float()
+ model_data['points_tgt'] = points_tgt.float()
+ return model_data, model_info
+
+ def get_deform_diff(self, index):
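+        """Assemble one diffusion sample from a single .h5 animation.
+
+        Summary of the returned `model_data` entries as constructed below
+        (P = pc_size, T = n_training_frames, K = max_num_forces):
+        `points_src` (1, P, 3) and `points_tgt` (T, P, 3), shifted by `norm_fac`
+        and scaled by 0.5; `force` (K, 3); `drag_point` (K, 4) with the number of
+        dragged points in the last channel; `mask` (K, P, 1); `E` (log10 of
+        Young's modulus), `nu`, `floor_height`, `gravity`, `mat_type`, `is_mpm`;
+        plus the simulator buffers `vol`, `F`, `C`.
+        """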
+
+ model = self.models[index]
+ model_name = model["model"]
+
+ model_info = {}
+ model_info["model"] = model_name
+ model_info["indices"] = np.arange(self.n_training_frames)
+
+ model_data = {}
+ model_data['model'] = model_name
+
+ model_metas = h5py.File(os.path.join(self.dataset_path, f'{model_name}'), 'r')
+ model_pcls = torch.from_numpy(np.array(model_metas['x']))
+
+ # if model_pcls[0].shape[0] > self.pc_size:
+ # ind = np.random.default_rng(seed=self.seed).choice(model_pcls[0].shape[0], self.pc_size, replace=False)
+ # points_src = model_pcls[:1]
+ # points_tgt = model_pcls[1:(self.n_training_frames*self.n_frames_interval+1):self.n_frames_interval][:, ind]
+ # else: # No need to do fps in new dataset case (input is 2048 points)
+ points_src = model_pcls[:1]
+ points_tgt = model_pcls[1:(self.n_training_frames*self.n_frames_interval+1):self.n_frames_interval]
+
+ if not 'drag_point' in model_metas: # Assume drag direction cross the sphere center
+ drag_dir = np.array(model_metas['drag_force'])
+ drag_dir = drag_dir / np.linalg.norm(drag_dir)
+ drag_point = np.array([self.cfg.norm_fac, self.cfg.norm_fac, self.cfg.norm_fac]) + drag_dir
+ else:
+ drag_point = np.array(model_metas['drag_point'])
+
+ if not 'floor_height' in model_metas:
+ model_data['floor_height'] = torch.from_numpy(np.array(-2.4)).unsqueeze(-1).float()
+ else:
+ model_data['floor_height'] = (torch.from_numpy(np.array(model_metas['floor_height'])).unsqueeze(-1).float() - self.cfg.norm_fac) / 2
+ model_data['drag_point'] = (torch.from_numpy(drag_point).float() - self.cfg.norm_fac) / 2
+ model_data['points_src'] = (points_src.float() - self.cfg.norm_fac) / 2
+ model_data['points_tgt'] = (points_tgt.float() - self.cfg.norm_fac) / 2
+
+ model_data['vol'] = torch.from_numpy(np.array(model_metas['vol']))
+ model_data['F'] = torch.from_numpy(np.array(model_metas['F']))
+ model_data['F'] = model_data['F'][1:(self.n_training_frames*self.n_frames_interval+1):self.n_frames_interval]
+ model_data['C'] = torch.from_numpy(np.array(model_metas['C']))
+ model_data['C'] = model_data['C'][1:(self.n_training_frames*self.n_frames_interval+1):self.n_frames_interval]
+
+ mask = torch.from_numpy(np.array(model_metas['drag_mask'])).bool()
+
+ if 'gravity' in model_metas:
+ model_data['gravity'] = torch.from_numpy(np.array(model_metas['gravity'])).long().unsqueeze(0)
+ else:
+ # print('no gravity in model_metas')
+ model_data['gravity'] = torch.from_numpy(np.array(0)).long().unsqueeze(0)
+
+ model_data['drag_point'] = (torch.from_numpy(drag_point).float() - self.cfg.norm_fac) / 2
+ if model_data['drag_point'].ndim == 1: # For compatibility: only have one force
+ model_data['drag_point'] = torch.cat([model_data['drag_point'], torch.tensor([mask.sum()]).float()], dim=0).unsqueeze(0)
+ else:
+ model_data['drag_point'] = torch.cat([model_data['drag_point'], mask.sum(dim=-1, keepdim=True).float()], dim=1)
+
+ force_order = torch.randperm(self.max_num_forces) if self.split == 'train' else torch.arange(self.max_num_forces)
+ mask = mask.unsqueeze(0) if mask.ndim == 1 else mask
+ # force_mask = torch.ones(self.max_num_forces, 1)
+ # force_mask[:mask.shape[0]] *= 0
+ # force_mask = force_mask[force_order].bool()
+
+ if mask.shape[1] == 0:
+ mask = torch.zeros(0, self.pc_size).bool()
+ model_data['force'] = torch.zeros(0, 3)
+ model_data['drag_point'] = torch.zeros(0, 4)
+ model_data['base_drag_coeff'] = torch.zeros(self.max_num_forces, 1)
+ elif not 'base_drag_coeff' in model_metas:
+ vol = model_data['vol'].unsqueeze(0)
+ total_volume = torch.sum(vol)
+ masked_volume = torch.sum(vol * mask, dim=1)
+ mean_masked_volume = masked_volume / mask.sum(dim=1)
+ mask_ratio = masked_volume / total_volume
+ base_drag_coeff = 9.8 * 1000 * mean_masked_volume / mask_ratio
+ weighted_force = torch.from_numpy(np.array(model_metas['drag_force'])).float()
+ weighted_force = weighted_force.unsqueeze(0) if weighted_force.ndim == 1 else weighted_force
+ model_data['force'] = weighted_force / base_drag_coeff.unsqueeze(1)
+ coeff = torch.zeros(self.max_num_forces, 1)
+ coeff = coeff[force_order]
+ coeff[:base_drag_coeff.shape[0]] = base_drag_coeff.unsqueeze(1)
+ model_data['base_drag_coeff'] = coeff
+ # model_data['weighted_force'] = weighted_force
+ else:
+ model_data['force'] = torch.from_numpy(np.array(model_metas['drag_force'])).float()
+ model_data['base_drag_coeff'] = torch.from_numpy(np.array(model_metas['base_drag_coeff'])).float()
+
+ model_data['is_mpm'] = torch.tensor(1).bool()
+ if 'mat_type' in model_metas:
+ model_data['mat_type'] = torch.from_numpy(np.array(model_metas['mat_type'])).long()
+ if np.array(model_data['mat_type']).item() == 3: # Rigid dataset
+ model_data['is_mpm'] = torch.tensor(0).bool()
+ else: # temporary fix for elastic data
+ model_data['mat_type'] = torch.tensor(0).long()
+
+ if self.has_gravity and model_data['gravity'][0] == 1: # add gravity to force
+ gravity = torch.tensor([[0, -1.0, 0]]).float()
+ drag_point = (model_data['points_src'][0] * (model_data['vol'] / model_data['vol'].sum()).unsqueeze(1)).sum(axis=0) if model_data['is_mpm'] else model_data['points_src'][0].mean(axis=0)
+ drag_point = torch.cat([drag_point, torch.tensor([self.pc_size]).float()]).unsqueeze(0)
+            assert model_data['force'].sum() == 0, f'combining drag force and gravity is not supported yet: {model_name}'
+ model_data['force'] = torch.cat([model_data['force'], gravity], dim=0) if not model_data['force'].sum() == 0 else gravity
+ model_data['drag_point'] = torch.cat([model_data['drag_point'], drag_point], dim=0) if not drag_point.sum() == 0 else drag_point
+ mask = torch.cat([mask, torch.ones_like(mask).bool()], dim=0) if not mask.sum() == 0 else torch.ones(1, self.pc_size).bool()
+
+ all_forces = torch.zeros(self.max_num_forces, 3)
+ all_forces[:model_data['force'].shape[0]] = model_data['force']
+ all_forces = all_forces[force_order]
+ model_data['force'] = all_forces
+
+ all_drag_points = torch.zeros(self.max_num_forces, 4)
+ all_drag_points[:model_data['drag_point'].shape[0]] = model_data['drag_point']
+ all_drag_points = all_drag_points[force_order]
+ model_data['drag_point'] = all_drag_points
+
+ if model_pcls[0].shape[0] > self.pc_size:
+ ind = np.random.default_rng(seed=self.seed).choice(model_pcls[0].shape[0], self.pc_size, replace=False)
+ model_data['points_src'] = model_data['points_src'][:, ind]
+ model_data['points_tgt'] = model_data['points_tgt'][:, ind]
+ mask = mask[:, ind] if mask.shape[-1] > self.pc_size else mask
+
+ all_mask = torch.zeros(self.max_num_forces, self.pc_size).bool()
+ all_mask[:mask.shape[0]] = mask
+ all_mask = all_mask[force_order]
+
+ model_data['mask'] = all_mask[..., None] # (n_forces, pc_size, 1) for compatibility
+ model_data['E'] = torch.log10(torch.from_numpy(np.array(model_metas['E'])).unsqueeze(-1).float()) if np.array(model_metas['E']) > 0 else torch.zeros(1).float()
+ model_data['nu'] = torch.from_numpy(np.array(model_metas['nu'])).unsqueeze(-1).float()
+
+ return model_data, model_info
\ No newline at end of file
diff --git a/src/eval.py b/src/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..888ed4c10f16b5110ffd9841e4955fb7833331cb
--- /dev/null
+++ b/src/eval.py
@@ -0,0 +1,72 @@
+from diffusers import DDPMScheduler, DDIMScheduler
+from dataset.traj_dataset import TrajDataset
+from model.mdm_dit import MDM_DiT
+from model.spacetime import MDM_ST
+import sys
+from options import TrainingConfig, TestingConfig
+from omegaconf import OmegaConf
+from pipeline_traj import TrajPipeline
+import torch
+from safetensors.torch import load_file
+import argparse
+import os
+import numpy as np
+from utils.physics import loss_momentum, DeformLoss
+import torch.nn.functional as F
+from tqdm import tqdm
+from utils.visualization import save_pointcloud_video, save_pointcloud_json, save_threejs_html, generate_html_from_exts
+
+def create_model(args):
+ model = MDM_ST(args.pc_size, args.train_dataset.n_training_frames, n_feats=3, model_config=args.model_config)
+ return model
+
+loss_deform = DeformLoss().to('cuda')
+def main(args):
+ val_dataset = TrajDataset('val', args.train_dataset)
+ val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.dataloader_num_workers)
+
+ device = 'cuda'
+ model = create_model(args).to(device)
+ ckpt = load_file(args.resume, device='cpu')
+ model.load_state_dict(ckpt, strict=False)
+ model.eval().requires_grad_(False)
+ model = torch.compile(model)
+ noise_scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type='sample', clip_sample=False)
+ pipeline = TrajPipeline(model=model, scheduler=noise_scheduler)
+
+ total_loss_p = 0.0
+ total_loss_xyz = 0.0
+ total_loss_F = 0.0
+ total_loss_F_gt = 0.0
+ for i, (batch, _) in enumerate(tqdm(val_dataloader)):
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ output = pipeline(batch['points_src'], batch['force'], batch['E'], batch['nu'], batch['mask'][..., :1], batch['drag_point'], batch['floor_height'], batch['gravity'], batch['base_drag_coeff'], y=None if args.model_config.get('num_mat', 0) == 0 else batch['mat_type'], device=device, batch_size=args.eval_batch_size, generator=torch.Generator().manual_seed(args.seed), n_frames=args.train_dataset.n_training_frames, num_inference_steps=args.num_inference_steps)
+ if 'vol' in batch:
+ loss_F = loss_deform(x=output.clamp(min=-2.2, max=2.2), vol=batch['vol'].to(device), F=batch['F'].to(device),
+ C=batch['C'].to(device), frame_interval=2, norm_fac=args.train_dataset.norm_fac)
+ loss_F_gt = loss_deform(x=batch['points_tgt'].to(device), vol=batch['vol'].to(device), F=batch['F'].to(device),
+ C=batch['C'].to(device), frame_interval=2, norm_fac=args.train_dataset.norm_fac)
+ total_loss_F += loss_F
+ total_loss_F_gt += loss_F_gt
+ total_loss_xyz += F.mse_loss(output, batch['points_tgt'].to(device))
+ output = output.cpu().numpy()
+ tgt = batch['points_tgt'].cpu().numpy()
+ vis_dir = args.vis_dir
+ save_dir = os.path.join(vis_dir, f'test_100_{args.num_inference_steps}steps_nips_debug')
+ os.makedirs(save_dir, exist_ok=True)
+ for j in range(output.shape[0]):
+ save_pointcloud_video(output[j:j+1].squeeze(), tgt[j:j+1].squeeze(), os.path.join(save_dir, f'{i*batch["points_src"].shape[0] + j:03d}_{batch["E"][j].item():03f}_{batch["nu"][j].item():03f}.gif'), drag_mask=batch['mask'][j:j+1, 0, :, 0].cpu().numpy().squeeze(), vis_flag='objaverse')
+ np.save(os.path.join(save_dir, f'{i*batch["points_src"].shape[0] + j}_{batch["E"][j].item():03f}_{batch["nu"][j].item():03f}.npy'), output[j:j+1].squeeze())
+ np.save(os.path.join(save_dir, f'{batch["model"][j]}.npy'), output[j:j+1].squeeze())
+ torch.cuda.empty_cache()
+ generate_html_from_exts(save_dir, os.path.join(save_dir, f'visualize.html'), 'gif')
+ print(total_loss_p, total_loss_xyz, total_loss_F, total_loss_F_gt)
+
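+# Example invocation (a sketch; paths refer to configs shipped in this repo):
+#   python src/eval.py --config src/configs/eval_base.yaml
+# after replacing the DATASET_FOLDER / DATASET_ITEM_LIST_TEST placeholders and
+# making sure the config also provides `resume` (checkpoint path) and `vis_dir`.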
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True)
+ args = parser.parse_args()
+ schema = OmegaConf.structured(TestingConfig)
+ cfg = OmegaConf.load(args.config)
+ cfg = OmegaConf.merge(schema, cfg)
+ main(cfg)
\ No newline at end of file
diff --git a/src/inference.py b/src/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5ac1b8bedf6ffca0e1395aec170b181198f25f4
--- /dev/null
+++ b/src/inference.py
@@ -0,0 +1,733 @@
+import os
+import argparse
+import json
+import sys
+import gc
+import random
+import warp as wp
+
+sys.path.append("../libs")
+sys.path.append("../libs/LGM")
+sys.path.append("../libs/vggt")
+sys.path.append("../libs/das")
+
+import numpy as np
+import trimesh
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+import cv2
+
+import h5py
+import tyro
+import imageio
+import open3d as o3d
+
+from tqdm import tqdm
+from PIL import Image
+from sklearn.decomposition import PCA
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from diffusers import AutoencoderKL, EulerDiscreteScheduler, DDPMScheduler
+from diffusers.utils import export_to_gif, export_to_video
+from kiui.cam import orbit_camera
+from safetensors.torch import load_file
+from torch_cluster import fps
+from omegaconf import OmegaConf
+
+from sv3d.diffusers_sv3d import SV3DUNetSpatioTemporalConditionModel, StableVideo3DDiffusionPipeline
+from LGM.core.models import LGM
+from LGM.core.options import AllConfigs
+from LGM.core.gs import GaussianRenderer
+from LGM.mvdream.pipeline_mvdream import MVDreamPipeline
+
+from vggt.models.vggt import VGGT
+from vggt.utils.load_fn import load_and_preprocess_images
+
+from utils.seeding import seed_everything
+from utils.track_utils.preprocessing import track_first, find_and_remove_nearest_point
+from utils.track_utils.visualize_tracks import visualize_tracks
+from utils.interpolate import *
+from utils.loading import paste_image
+from utils.image_process import image_preprocess, pred_bbox, sam_init, sam_out_nosave, resize_image
+from utils.transform import transform2origin, shift2center_th
+from utils.sim_utils import get_particle_volume
+
+# Diffusion
+from model.spacetime import MDM_ST
+from pipeline_traj import TrajPipeline
+from options import TestingConfig
+
+device = torch.device("cuda")
+
+def run_track(args, output_dir):
+
+ N = 2048
+ frame_num = 49
+
+ animated_points = np.load(f'{output_dir}/gen_data.npy')
+ animated_points = animated_points * 2
+ new_animate_points = np.zeros((frame_num, N, 3))
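+    # Upsample the 24 generated frames to 49: generated frame k lands on output
+    # frame 2k + 1, even output frames (2..46) are midpoints of their neighbours,
+    # and frames 0 and 48 copy frames 1 and 47.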
+ for i in range(frame_num - 2): # Interpolate since we only generate 24 frames
+ if i % 2 == 0:
+ new_animate_points[i + 1] = animated_points[i // 2]
+ else:
+ new_animate_points[i + 1] = (animated_points[i // 2] + animated_points[i // 2 + 1]) / 2
+ new_animate_points[0] = new_animate_points[1]
+ new_animate_points[frame_num - 1] = new_animate_points[frame_num - 2]
+ animated_points = new_animate_points
+
+ projection_matrix = np.load('templates/projection.npy')
+ crop_info = np.load(f'{output_dir}/crop_info.npy')
+ center = np.load(f'{output_dir}/center.npy')
+ scale = np.load(f'{output_dir}/scale.npy')
+ animated_points = (animated_points / scale) + center
+
+ # Aligned to Gaussian points at this moment
+ sys.argv = ['pipeline_track_gen.py', 'big']
+ opt = tyro.cli(AllConfigs)
+
+ scale_factor = 1
+ focal = 0.5 * opt.output_size / np.tan(np.deg2rad(opt.fovy) / 2)
+ new_fovy_rad = scale_factor * np.arctan(opt.output_size / focal)
+ new_fovy_deg = np.rad2deg(new_fovy_rad)
+ opt.fovy = new_fovy_deg
+    opt.output_size *= scale_factor  # Scale the canvas size (scale_factor = 1 leaves it unchanged)
+
+ gs = GaussianRenderer(opt)
+ gaussians = gs.load_ply(f'{output_dir}/point_cloud.ply', compatible=True).to(device).float()
+ idx = torch.from_numpy(np.load(f'{output_dir}/fps_idx.npy')).to(device)
+ gaussian_pos = gaussians[:, :3].contiguous()
+ drive_x = gaussian_pos[idx]
+ cdist = -1.0 * torch.cdist(gaussian_pos, drive_x) # [N, 2048]
+ _, topk_index = torch.topk(cdist, 8, -1)
+
+ cam_poses = torch.from_numpy(orbit_camera(0, 0, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ gs.proj_matrix.to(device) # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ pos = []
+ frames = []
+ input_raw = np.array(Image.open(f'{args.base_dir}/{args.data_name}/input.png'))
+ input_mask = np.array(Image.open(f'{output_dir}/input_mask.png').convert('L'))
+    input_raw[input_mask != 0] = 0  # Zero out the pixels covered by the mask (mask != 0)
+ input_raw = Image.fromarray(input_raw)
+
+ for i in tqdm(range(0, frame_num, 1)):
+ drive_current = torch.from_numpy(animated_points[i]).to(device).float()
+ ret_points, new_rotation = interpolate_points(gaussian_pos, gaussians[:, 7:11], drive_x, drive_current, topk_index)
+ gaussians_new = gaussians.clone()
+ gaussians_new[:, :3] = ret_points
+ gaussians_new[:, 7:11] = new_rotation
+ pos.append(ret_points.cpu().numpy())
+
+ track_template = np.load(f'templates/tracks_template.npy', allow_pickle=True)
+ tracks = track_template.item()['tracks']
+ tracks_output = tracks.copy()
+ tracks_init = tracks[0, 0]
+ track_idx = []
+ mask = np.zeros(tracks_init.shape[0], dtype=bool)
+
+ h_begin, w_begin, res = crop_info[0], crop_info[1], crop_info[2]
+ image_shape = (res, res) # Example image shape (H, W)
+
+ drag_points = []
+
+ for i in tqdm(range(frame_num)):
+
+ points = pos[i]
+ projected_points = (projection_matrix.T @ np.hstack((points, np.ones((points.shape[0], 1)))).T).T
+ projected_points_weights = 1. / (projected_points[:, -1:] + 1e-8)
+ projected_points = (projected_points * projected_points_weights)[:, :-1]
+
+ projected_points[:, :2] = ((projected_points[:, :2] + 1) * image_shape[1] - 1) / 2
+ projected_points[:, 0] += w_begin
+ projected_points[:, 1] += h_begin
+ drag_points.append(projected_points.mean(axis=0))
+
+ if i == 0:
+ track_point_candidates = track_first(projected_points, (480, 720))
+ for j in range(tracks_init.shape[0]):
+ x, y = tracks_init[j, 0], tracks_init[j, 1]
+ target = np.array([x, y])
+ candidate, track_point_candidates = find_and_remove_nearest_point(target, track_point_candidates)
+ if candidate is not None:
+ track_idx.append(candidate[3].astype(np.int32))
+ mask[j] = True
+
+ tracks_output[0, i, mask] = projected_points[track_idx]
+ tracks_output[0, i, ~mask, :2] = tracks_output[0, 0, ~mask, :2]
+ tracks_output[0, i, ~mask, 2] = 2
+
+ track_template.item()['tracks'] = tracks_output
+ track_template.item()['drag_points'] = np.stack(drag_points, axis=0)
+ sub_dir = f'{output_dir}/tracks_gen'
+ os.makedirs(sub_dir, exist_ok=True)
+
+ np.save(f'{sub_dir}/tracks.npy', track_template)
+ visualize_tracks(tracks_dir=sub_dir, output_dir=sub_dir, args=args)
+
+def run_diffusion(args, output_dir):
+
+ schema = OmegaConf.structured(TestingConfig)
+ cfg = OmegaConf.load(args.model_cfg_path)
+ cfg = OmegaConf.merge(schema, cfg)
+ n_training_frames = cfg.train_dataset.n_training_frames
+ n_frames_interval = cfg.train_dataset.n_frames_interval
+ norm_fac = cfg.train_dataset.norm_fac
+ model = MDM_ST(cfg.pc_size, n_training_frames, n_feats=3, model_config=cfg.model_config).to(device)
+
+ ckpt = load_file(args.model_path, device='cpu')
+ model.load_state_dict(ckpt, strict=True)
+ model.eval().requires_grad_(False)
+ noise_scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type='sample', clip_sample=False)
+ pipeline = TrajPipeline(model=model, scheduler=noise_scheduler)
+
+ pc_path = f'{output_dir}/point_cloud.ply'
+ pc = trimesh.load_mesh(pc_path)
+ points = pc.vertices
+ points = np.array(points)
+ points, center, scale = transform2origin(points, size=1)
+ np.save(f'{output_dir}/center.npy', center)
+ np.save(f'{output_dir}/scale.npy', scale)
+
+ N = 2048
+ max_num_forces = 1
+ has_gravity = args.mat_label > 0
+
+ points = torch.tensor(points, dtype=torch.float32, device=device).contiguous()
+ ratio_N = N / points.shape[0]
+ idx = fps(points, ratio=ratio_N, random_start=True)
+ np.save(f'{output_dir}/fps_idx.npy', idx.cpu().numpy())
+ points_tensor = points[idx].contiguous()
+ points_center = shift2center_th(points_tensor) # MPM coordinate
+ points = points_tensor.cpu().numpy()
+
+ # User input
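+    # The per-scene config.json is expected to contain roughly the fields read
+    # below (a sketch; the values are illustrative only):
+    #   {
+    #     "drag_mode": "point",            # or "max" together with "drag_axis"
+    #     "drag_point": [0.0, 0.5, 0.0],
+    #     "drag_dir": [0.0, 1.0, 0.0],
+    #     "force_coeff": 1.0,
+    #     "material": "elastic",           # see the mat_label print below
+    #     "log_E": 6.0,
+    #     "nu": 0.4,
+    #     "floor_loc_begin": [400, 100],   # (row, col) bounds of a floor patch, used in run_vggt
+    #     "floor_loc_end": [460, 300]
+    #   }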
+ if "drag_mode" in cfg_json:
+ if cfg_json["drag_mode"] == "point":
+ drag_point = np.array(cfg_json["drag_point"])
+        elif cfg_json["drag_mode"] in ("max", "min"):
+            drag_point_idx = np.argmax(points[:, cfg_json["drag_axis"]]) if cfg_json["drag_mode"] == "max" \
+                else np.argmin(points[:, cfg_json["drag_axis"]])
+ drag_point = points[drag_point_idx]
+ else:
+ raise ValueError(f"Invalid drag mode: {cfg_json['drag_mode']}")
+ drag_offset = np.abs(points - drag_point)
+ drag_mask = (drag_offset < 0.4).all(axis=-1)
+ drag_dir = np.array(cfg_json["drag_dir"], dtype=np.float32)
+ drag_dir /= np.linalg.norm(drag_dir)
+ drag_force = drag_dir * np.array(cfg_json["force_coeff"])
+ else:
+ drag_mask = np.ones(N, dtype=bool)
+ drag_point = np.zeros(4)
+ drag_dir = np.zeros(3)
+ drag_force = np.zeros(3)
+
+ if cfg_json["material"] == "elastic":
+ log_E, nu = np.array(cfg_json["log_E"]), np.array(cfg_json["nu"])
+ else:
+ log_E, nu = np.array(6), np.array(0.4) # Default values for non-elastic materials
+
+ print(f'[Diffusion Simulation] Number of drag points: {drag_mask.sum()}/{N}')
+ print(f'[Diffusion Simulation] Drag point: {drag_point}')
+ print(f'[Diffusion Simulation] log_E: {log_E}, ν: {nu}')
+ print(f'[Diffusion Simulation] Drag force: {drag_force}')
+ print(f'[Diffusion Simulation] Material type: {cfg_json["material"]}({args.mat_label})')
+ print(f'[Diffusion Simulation] Has gravity: {has_gravity}')
+
+ force_order = torch.arange(max_num_forces)
+ mask = torch.from_numpy(drag_mask).bool()
+ mask = mask.unsqueeze(0) if mask.ndim == 1 else mask
+
+ batch = {}
+ batch['gravity'] = torch.from_numpy(np.array(has_gravity)).long().unsqueeze(0)
+ batch['drag_point'] = torch.from_numpy(drag_point).float() / 2
+ batch['drag_point'] = batch['drag_point'].unsqueeze(0) # (1, 4)
+ batch['points_src'] = points_tensor.float().unsqueeze(0) / 2
+
+ if has_gravity:
+ floor_normal = np.load(f'{output_dir}/floor_normal.npy')
+ floor_height = np.load(f'{output_dir}/floor_height.npy') * scale / 2.
+ batch['floor_height'] = torch.from_numpy(np.array(floor_height)).float().unsqueeze(0)
+
+ # Create rotation matrix to align floor normal with [0, 1, 0] (upward direction)
+ target_normal = np.array([0, 1, 0])
+
+ # Use Rodrigues' rotation formula to find rotation matrix
+ # Rotate from floor_normal to target_normal
+ v = np.cross(floor_normal, target_normal)
+ s = np.linalg.norm(v)
+ c = np.dot(floor_normal, target_normal)
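+        # Rodrigues' formula below: R = I + s*K + (1 - c)*K^2, where K is the
+        # skew-symmetric cross-product matrix of the unit rotation axis v,
+        # and s = sin(theta), c = cos(theta) when floor_normal is unit length.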
+
+ if s < 1e-6: # If vectors are parallel
+ if c > 0: # Same direction
+ R_floor = np.eye(3)
+ else: # Opposite direction
+ R_floor = -np.eye(3)
+ else:
+ v = v / s
+ K = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
+ R_floor = np.eye(3) + s * K + (1 - c) * (K @ K)
+
+ R_floor_tensor = torch.from_numpy(R_floor).float().to(device)
+ for i in range(batch['points_src'].shape[0]):
+ batch['points_src'][i] = (R_floor_tensor @ batch['points_src'][i].T).T
+ else:
+ batch['floor_height'] = torch.ones(1).float() * -2.4
+
+ print(f'[Diffusion Simulation] Floor height: {batch["floor_height"]}')
+
+ if mask.shape[1] == 0:
+ mask = torch.zeros(0, N).bool()
+ batch['force'] = torch.zeros(0, 3)
+ batch['drag_point'] = torch.zeros(0, 4)
+ else:
+ batch['force'] = torch.from_numpy(drag_force).float().unsqueeze(0)
+
+ batch['mat_type'] = torch.from_numpy(np.array(args.mat_label)).long()
+ if np.array(batch['mat_type']).item() == 3: # Rigid dataset
+ batch['is_mpm'] = torch.tensor(0).bool()
+ else:
+ batch['is_mpm'] = torch.tensor(1).bool()
+
+ if has_gravity: # Currently we only have either drag force or gravity
+ batch['force'] = torch.tensor([[0, -1.0, 0]]).to(device)
+
+ all_forces = torch.zeros(max_num_forces, 3)
+ all_forces[:batch['force'].shape[0]] = batch['force']
+ all_forces = all_forces[force_order]
+ batch['force'] = all_forces
+
+ all_drag_points = torch.zeros(max_num_forces, 4)
+ all_drag_points[:batch['drag_point'].shape[0], :batch['drag_point'].shape[1]] = batch['drag_point'] # The last dim of drag_point is not used now
+ all_drag_points = all_drag_points[force_order]
+ batch['drag_point'] = all_drag_points
+
+ if batch['gravity'][0] == 1: # add gravity to force
+ batch['force'] = torch.tensor([[0, -1.0, 0]]).float().to(device)
+
+ all_mask = torch.zeros(max_num_forces, N).bool()
+ all_mask[:mask.shape[0]] = mask
+ all_mask = all_mask[force_order]
+
+ batch['mask'] = all_mask[..., None] # (n_forces, N, 1) for compatibility
+ batch['E'] = torch.from_numpy(log_E).unsqueeze(-1).float() if log_E > 0 else torch.zeros(1).float()
+ batch['nu'] = torch.from_numpy(nu).unsqueeze(-1).float()
+
+ for k in batch:
+ batch[k] = batch[k].unsqueeze(0).to(device)
+
+
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ output = pipeline(batch['points_src'], batch['force'], batch['E'], batch['nu'], batch['mask'][..., :1],
+ batch['drag_point'], batch['floor_height'], batch['gravity'], coeff=batch['E'], generator=torch.Generator().manual_seed(args.seed),
+ device=device, batch_size=1, y=batch['mat_type'], n_frames=n_training_frames, num_inference_steps=25)
+ output = output.cpu().numpy()
+ for j in range(output.shape[0]):
+ if batch['gravity'][0] == 1:
+ for k in range(output.shape[1]):
+ output[j, k] = (np.linalg.inv(R_floor) @ output[j, k].T).T
+ np.save(f'{output_dir}/gen_data.npy', output[j:j+1].squeeze())
+
+def run_vggt(args, output_dir):
+
+ if not os.path.exists(f'{output_dir}/est_pcd.npy'):
+
+ model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
+ if os.path.exists(f'{args.base_dir}/{args.data_name}/input_ori.png'):
+ image_names = [f'{args.base_dir}/{args.data_name}/input_ori.png']
+ else:
+ image_names = [f'{args.base_dir}/{args.data_name}/input.png']
+
+ images = []
+ for image_name in image_names:
+ image = Image.open(image_name)
+ image = np.array(image)[2:-2, 3:-3]
+ image = image.astype(np.float32) / 255.0
+ images.append(image)
+ images = np.stack(images, axis=0)
+ images = torch.from_numpy(images).permute(0, 3, 1, 2).float().to(device)
+ images = images[:, :3]
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(dtype=torch.float16):
+ # Predict attributes including cameras, depth maps, and point maps.
+ predictions = model(images)
+
+ est_pcd = predictions['world_points'].cpu().numpy()
+ depth = predictions['depth'].cpu().numpy()
+ Image.fromarray((depth[0, 0, :, :, 0] * 255).astype(np.uint8)).save(f'{output_dir}/est_depth.png')
+ np.save(f'{output_dir}/est_pcd.npy', est_pcd)
+ est_pcd_export = trimesh.PointCloud(est_pcd.reshape(-1, 3))
+ est_pcd_export.export(f'{output_dir}/est_pcd.ply')
+
+ cfg_json_path = f'{args.base_dir}/{args.data_name}/config.json'
+ with open(cfg_json_path, 'r') as f:
+ cfg_json = json.load(f)
+ floor_loc_begin = np.array(cfg_json["floor_loc_begin"])
+ floor_loc_end = np.array(cfg_json["floor_loc_end"])
+
+ input_mask = np.array(Image.open(f'{output_dir}/input_mask.png').convert('L'))
+ input_mask_eroded = input_mask.copy()
+ kernel = np.ones((5, 5), np.uint8)
+ input_mask_eroded = cv2.erode(input_mask_eroded, kernel, iterations=1)
+ Image.fromarray(input_mask_eroded).save(f'{output_dir}/input_mask_eroded.png')
+
+ est_pcd = np.load(f'{output_dir}/est_pcd.npy')[0, 0]
+ est_pcd = np.pad(est_pcd, ((2, 2), (3, 3), (0, 0)), mode='constant', constant_values=0)
+ est_pcd_masked = est_pcd[input_mask_eroded > 0].reshape(-1, 3)
+ est_pcd_floor = est_pcd[floor_loc_begin[0]:floor_loc_end[0],
+ floor_loc_begin[1]:floor_loc_end[1]].reshape(-1, 3)
+
+ bmax = est_pcd_masked.max(axis=0)
+ bmin = est_pcd_masked.min(axis=0)
+ aabb = bmax - bmin
+ center = (bmax + bmin) / 2
+ scale = aabb.max()
+ est_pcd = (est_pcd - center) / scale
+ est_pcd_masked = (est_pcd_masked - center) / scale
+ est_pcd_floor = (est_pcd_floor - center) / scale
+
+ projection_matrix = np.load('templates/projection.npy')
+ crop_info = np.load(f'{output_dir}/crop_info.npy')
+ h_begin, w_begin, res = crop_info[0], crop_info[1], crop_info[2]
+ image_shape = (res, res) # Example image shape (H, W)
+
+ pc_path = f'{output_dir}/point_cloud.ply'
+ pc = trimesh.load_mesh(pc_path)
+ points = pc.vertices
+ points = np.array(points)
+
+ projected_points = (projection_matrix.T @ np.hstack((points, np.ones((points.shape[0], 1)))).T).T
+ projected_points_weights = 1. / (projected_points[:, -1:] + 1e-8)
+ projected_points = (projected_points * projected_points_weights)[:, :-1]
+
+ projected_points[:, :2] = ((projected_points[:, :2] + 1) * image_shape[1] - 1) / 2
+ projected_points[:, 0] += w_begin
+ projected_points[:, 1] += h_begin
+
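+    # Z-buffer splat of the reconstructed point cloud onto the 480x720 image plane; 233 is just a
+    # large sentinel depth so the nearest projected point at each pixel wins.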
+ gt_pcd = np.zeros((480, 720, 3))
+ min_z = np.ones((480, 720)) * 233
+ for i, project_point in enumerate(projected_points):
+ y, x = int(project_point[1]), int(project_point[0])
+ if project_point[2] < min_z[y, x]:
+ gt_pcd[y, x] = points[i]
+ min_z[y, x] = project_point[2]
+
+ gt_pcd_masked = gt_pcd[input_mask_eroded > 0]
+ min_z_masked = min_z[input_mask_eroded > 0]
+ min_z_num = min_z_masked.shape[0]
+ z_values_threshold = np.sort(min_z_masked)[min_z_num // 3]
+
+ est_pcd_masked_ori = est_pcd_masked.copy()
+ est_pcd_masked = est_pcd_masked[min_z_masked < z_values_threshold]
+ gt_pcd_masked = gt_pcd_masked[min_z_masked < z_values_threshold]
+
+ est_pcd_masked_export = trimesh.PointCloud(est_pcd_masked)
+ est_pcd_masked_export.export(f'{output_dir}/est_pcd_masked.ply')
+ gt_pcd_masked_export = trimesh.PointCloud(gt_pcd_masked)
+ gt_pcd_masked_export.export(f'{output_dir}/gt_pcd_masked.ply')
+
+ # Use least squares to find the best-fit similarity transformation (rotation + translation + scale)
+ # between est_pcd_masked and gt_pcd_masked (correspondences are known and ordered)
+ # This is an extension of the Kabsch algorithm to include scaling
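+    # Closed form: with H = X_c^T Y_c and SVD H = U S V^T, R = V U^T (sign-corrected so det(R) = +1),
+    # s = tr(R^T H) / tr(X_c^T X_c), and t = y_mean - s R x_mean, where X_c, Y_c are the centered clouds.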
+
+ # Compute centroids
+ est_centroid = np.mean(est_pcd_masked, axis=0)
+ gt_centroid = np.mean(gt_pcd_masked, axis=0)
+
+ # Center the point clouds
+ est_centered = est_pcd_masked - est_centroid
+ gt_centered = gt_pcd_masked - gt_centroid
+
+ # Compute covariance matrix
+ H = est_centered.T @ gt_centered
+
+ # SVD
+ U, S, Vt = np.linalg.svd(H)
+ R = Vt.T @ U.T
+
+ # Ensure a proper rotation (determinant = 1)
+ if np.linalg.det(R) < 0:
+ Vt[-1, :] *= -1
+ R = Vt.T @ U.T
+
+ # Compute scale factor
+ scale = np.trace(R.T @ H) / np.trace(est_centered.T @ est_centered)
+
+ # Compute translation
+ t = gt_centroid - scale * R @ est_centroid
+
+ # Compose transformation matrix
+ transform = np.eye(4)
+ transform[:3, :3] = scale * R
+ transform[:3, 3] = t
+
+ # Apply transformation
+ est_pcd_masked_ori_transformed = scale * (R @ est_pcd_masked_ori.T).T + t
+ est_pcd_transformed = scale * (R @ est_pcd_masked.T).T + t
+ est_pcd_transformed_export = trimesh.PointCloud(est_pcd_transformed)
+ est_pcd_transformed_export.export(f'{output_dir}/est_pcd_masked_transformed.ply')
+ est_pcd_floor_transformed = scale * (R @ est_pcd_floor.T).T + t
+ est_pcd_floor_transformed_export = trimesh.PointCloud(est_pcd_floor_transformed)
+ est_pcd_floor_transformed_export.export(f'{output_dir}/est_pcd_floor_transformed.ply')
+
+ # Compute RMSE for the alignment
+ alignment_rmse = np.sqrt(np.mean(np.sum((est_pcd_transformed - gt_pcd_masked) ** 2, axis=1)))
+
+ # Fit a plane using PCA to get normal vector and center point
+ center = np.mean(est_pcd_floor_transformed, axis=0)
+ pca = PCA(n_components=3)
+ pca.fit(est_pcd_floor_transformed)
+ normal = pca.components_[2] # Last component is normal to plane
+
+ # Calculate floor height as distance between the center of est_pcd_masked and the fitted floor plane
+ d = -np.dot(normal, center) # d parameter for plane equation
+ est_centroid = np.mean(est_pcd_masked_ori_transformed, axis=0) # center of est_pcd_masked
+ est_centroid[1] = 0 # set y to 0
+ floor_height = np.abs(np.dot(est_centroid, normal) + d) / np.linalg.norm(normal)
+
+ print(f"[Floor Alignment] Floor Height: {-floor_height}")
+ print(f"[Floor Alignment] Floor Normal: {normal}")
+ np.save(f'{output_dir}/floor_normal.npy', normal)
+ np.save(f'{output_dir}/floor_height.npy', -floor_height)
+
+def run_LGM(args, output_dir):
+
+ device = torch.device("cuda")
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+ sys.argv = ['pipeline_track_gen.py', 'big']
+ opt = tyro.cli(AllConfigs)
+
+ model = LGM(opt)
+ ckpt = load_file(args.lgm_ckpt_path, device='cpu')
+ model.load_state_dict(ckpt, strict=False)
+ model = model.half().to(device)
+ model.eval()
+
+ rays_embeddings = model.prepare_default_rays(device)
+ tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
+ proj_matrix[0, 0] = 1 / tan_half_fov
+ proj_matrix[1, 1] = 1 / tan_half_fov
+ proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+ proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+ proj_matrix[2, 3] = 1
+
+ images = []
+ for i in range(4):
+ image = Image.open(f"{output_dir}/view_{i}.png")
+ image = image.resize((256, 256))
+ image = np.array(image)
+ image = image.astype(np.float32) / 255.0
+ if image.shape[-1] == 4:
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+ images.append(image)
+ mv_image = np.stack(images, axis=0)
+
+ # generate gaussians
+ input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
+ input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
+ input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+ input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
+
+ with torch.no_grad():
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
+ # generate gaussians
+ gaussians = model.forward_gaussians(input_image)
+
+ # save gaussians
+ model.gs.save_ply(gaussians, f'{output_dir}/point_cloud.ply')
+
+ # render front view
+ cam_poses = torch.from_numpy(orbit_camera(0, 0, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+ # cam_poses = torch.from_numpy(orbit_camera(45, 225, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
+ image_save = (image[0, 0].permute(1, 2, 0).contiguous().float().cpu().numpy() * 255).astype(np.uint8)
+ Image.fromarray(image_save).save(f'{output_dir}/front_view.png')
+
+ images = []
+ azimuth = np.arange(0, 360, 2, dtype=np.int32)
+ elevation = 0
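+    # Render a full 360-degree orbit in 2-degree steps (180 frames, ~6 s at 30 fps).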
+
+ for azi in tqdm(azimuth):
+
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+
+ images = np.concatenate(images, axis=0)
+ imageio.mimwrite(f'{output_dir}/gs_animation.mp4', images, fps=30)
+
+def run_sv3d(args, output_dir):
+
+ model_path = "chenguolin/sv3d-diffusers"
+ data_dir = f'{output_dir}/data'
+ os.makedirs(data_dir, exist_ok=True)
+
+ num_frames, sv3d_res = 20, 576
+ elevations_deg = [args.elevation] * num_frames
+ polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
+ azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
+ azimuths_rad = [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
+ azimuths_rad[:-1].sort()
+
+ unet = SV3DUNetSpatioTemporalConditionModel.from_pretrained(model_path, subfolder="unet")
+ vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
+ scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_path, subfolder="image_encoder")
+ feature_extractor = CLIPImageProcessor.from_pretrained(model_path, subfolder="feature_extractor")
+
+ pipeline = StableVideo3DDiffusionPipeline(
+ image_encoder=image_encoder, feature_extractor=feature_extractor,
+ unet=unet, vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipeline = pipeline.to("cuda")
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=torch.float16, enabled=True):
+
+ image = Image.open(f'{output_dir}/input_processed.png')
+ if len(image.split()) == 4: # RGBA
+ input_image = Image.new("RGB", image.size, (255, 255, 255)) # pure white bg
+ input_image.paste(image, mask=image.split()[3]) # 3rd is the alpha channel
+ else:
+ input_image = image
+
+ video_frames = pipeline(
+ input_image.resize((sv3d_res, sv3d_res)),
+ height=sv3d_res,
+ width=sv3d_res,
+ num_frames=num_frames,
+ decode_chunk_size=8, # smaller to save memory
+ polars_rad=polars_rad,
+ azimuths_rad=azimuths_rad,
+ generator=torch.manual_seed(args.seed),
+ ).frames[0]
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ export_to_gif(video_frames, f"{output_dir}/view_animation.gif", fps=7)
+ for i, frame in enumerate(video_frames):
+ frame.save(f"{data_dir}/{i:03d}.png")
+
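+    # With num_frames=20, indices 19/4/9/14 correspond to azimuths of roughly 360(=0), 90, 180 and
+    # 270 degrees, i.e. four approximately orthogonal views used as input to LGM.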
+ save_idx = [19, 4, 9, 14]
+ for i in range(4):
+ video_frames[save_idx[i]].save(f"{output_dir}/view_{i}.png")
+
+def run_sam(args, output_dir):
+
+ # Load SAM checkpoint
+ sv3d_res = 576
+ sam_predictor = sam_init(args.sam_ckpt_path)
+ print("[SAM] Loaded SAM model")
+
+ input_raw = Image.open(f'{args.base_dir}/{args.data_name}/input.png') if not os.path.exists(f'{args.base_dir}/{args.data_name}/input_masked.png') else Image.open(f'{args.base_dir}/{args.data_name}/input_masked.png')
+ input_sam = sam_out_nosave(sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw))
+ mask = np.array(input_sam)[:, :, 3]
+ Image.fromarray(mask).save(f"{output_dir}/input_mask.png")
+ y, x, res = image_preprocess(input_sam, f"{output_dir}/input_processed.png", target_res=sv3d_res,
+ lower_contrast=False, rescale=True)
+ np.save(f"{output_dir}/crop_info.npy", np.array([y, x, res]))
+
+from das.models.pipelines import DiffusionAsShaderPipeline
+from das.infer import load_media
+def run_das(args, output_dir, prompt, seed):
+ output_dir = os.path.join(args.output_dir, args.data_name)
+ das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=os.path.join(args.output_dir, args.data_name))
+ video_tensor, fps, is_video = load_media(f'{args.base_dir}/{args.data_name}/input.png')
+ tracking_tensor, _, _ = load_media(os.path.join(args.output_dir, args.data_name, 'tracks_gen', 'tracking', 'tracks_tracking.mp4'))
+ das.apply_tracking(
+ video_tensor=video_tensor,
+ fps=24,
+ tracking_tensor=tracking_tensor,
+ img_cond_tensor=None,
+ prompt=prompt,
+ checkpoint_path=args.das_ckpt_path,
+ seed=seed
+ )
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base_dir", default="../examples", type=str, help="Base dir")
+ parser.add_argument("--output_dir", default="../outputs", type=str, help="Output filepath")
+ parser.add_argument("--sam_ckpt_path", default="../checkpoints/sam_vit_h_4b8939.pth")
+ parser.add_argument("--lgm_ckpt_path", default="../checkpoints/lgm_fp16.safetensors")
+ parser.add_argument("--das_ckpt_path", default="../checkpoints/cogshader5B")
+ parser.add_argument("--base_ckpt_path", default="../checkpoints/physctrl_base.safetensors")
+ parser.add_argument("--large_ckpt_path", default="../checkpoints/physctrl_large.safetensors")
+ parser.add_argument("--gpu", type=int, default=0)
+ parser.add_argument("--data_name", default="chair", type=str, help="Data Name")
+ parser.add_argument("--base_cfg_path", default="configs/eval_base.yaml", type=str, help="Model config")
+ parser.add_argument("--large_cfg_path", default="configs/eval_large.yaml", type=str, help="Model config")
+ parser.add_argument("--elevation", default=0, type=float, help="Camera elevation of the input image")
+ parser.add_argument("--seed", default=0, type=int, help="Random seed")
+ parser.add_argument('--tracks_dir', type=str, default='', help='DAS Tracking data directory')
+ parser.add_argument('--output_fps', type=int, default=24, help='DAS Output video FPS')
+ parser.add_argument('--point_size', type=int, default=10, help='DAS Tracking point size')
+ parser.add_argument('--len_track', type=int, default=0, help='DAS Tracking trajectory length')
+ parser.add_argument('--num_frames', type=int, default=49, help='DAS Number of frames to generate black video')
+
+ args = parser.parse_args()
+ seed_everything(args.seed)
+ mat_labels = {'elastic': 0, 'plasticine': 1, 'sand': 2, 'rigid': 3}
+
+ output_dir = f'{args.output_dir}/{args.data_name}'
+ cfg_json_path = f'{args.base_dir}/{args.data_name}/config.json'
+ with open(cfg_json_path, 'r') as f:
+ cfg_json = json.load(f)
+ args.model_path = args.base_ckpt_path
+ args.model_cfg_path = args.base_cfg_path
+
+ mat_type = cfg_json['material']
+ if mat_type in mat_labels:
+ args.mat_label = mat_labels[mat_type]
+ else:
+ raise ValueError(f"Invalid material type: {mat_type}")
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ ## Run SAM to preprocess the input image
+ run_sam(args, output_dir)
+
+    ## Run SV3D to generate 20 novel-view frames around the object
+ run_sv3d(args, output_dir)
+
+ ## Run LGM to reconstruct the 3D model
+ run_LGM(args, output_dir)
+
+ ## Run VGGT to infer floor height and floor normal
+ if args.mat_label > 0:
+ args.model_path = args.large_ckpt_path
+ args.model_cfg_path = args.large_cfg_path
+ run_vggt(args, output_dir)
+
+ ## Run Generation to get results and tracks
+ run_diffusion(args, output_dir)
+ run_track(args, output_dir)
+
+ ## Run Video Generation
+ prompt = cfg_json['prompt']
+ run_das(args, output_dir, prompt, seed=cfg_json['seed'] if 'seed' in cfg_json else 42)
+
+
+
\ No newline at end of file
diff --git a/src/model/dit.py b/src/model/dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc3b4ed25a28595d953b3ec1c93b2274799d0f9e
--- /dev/null
+++ b/src/model/dit.py
@@ -0,0 +1,598 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils import is_accelerate_version, is_accelerate_available
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ This function generates 1D positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension `D`
+ pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+ Returns:
+ `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.outer(pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb
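+# Shape example: get_1d_sincos_pos_embed_from_grid(384, torch.arange(2048)) returns a (2048, 384) tensor.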
+
+def get_3d_sincos_pos_embed(
+ embed_dim: int,
+ spatial_size: Union[int, Tuple[int, int]],
+ temporal_size: int,
+ spatial_interpolation_scale: float = 1.0,
+ temporal_interpolation_scale: float = 1.0,
+ device: Optional[torch.device] = None,
+ output_type: str = "np",
+) -> torch.Tensor:
+ r"""
+ Creates 3D sinusoidal positional embeddings.
+
+ Args:
+ embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 4.
+        spatial_size (`int`):
+            The number of spatial positions (points per frame) in the positional embeddings.
+        temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
+ spatial_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for spatial grid interpolation.
+ temporal_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for temporal grid interpolation.
+
+ Returns:
+ `torch.Tensor`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size, embed_dim]`.
+ """
+
+ if embed_dim % 4 != 0:
+ raise ValueError("`embed_dim` must be divisible by 4")
+
+ embed_dim_spatial = 3 * embed_dim // 4
+ embed_dim_temporal = embed_dim // 4
+
+ # 1. Spatial
+ grid_pc = torch.arange(spatial_size, device=device, dtype=torch.float32) / spatial_interpolation_scale
+ pos_embed_spatial = get_1d_sincos_pos_embed_from_grid(embed_dim_spatial, grid_pc)
+
+ # 2. Temporal
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+ # 3. Concat
+ pos_embed_spatial = pos_embed_spatial[None, :, :]
+ pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
+ pos_embed_temporal = pos_embed_temporal[:, None, :]
+ pos_embed_temporal = pos_embed_temporal.repeat_interleave(spatial_size, dim=1) # [T, H*W, D // 4]
+ pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
+ return pos_embed
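+# Shape example: get_3d_sincos_pos_embed(512, 2048, 48) returns a (48, 2048, 512) tensor, where 3/4 of
+# the channels encode the point (spatial) index and 1/4 encode the frame (temporal) index.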
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+    A Transformer model adapted from [CogVideoX](https://github.com/THUDM/CogVideo) to operate on
+    sequences of 3D point tokens (per-frame point sets) instead of video patches.
+
+    Parameters:
+        num_attention_heads (`int`, defaults to `8`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `3`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `3`):
+            The number of channels in the output.
+        flip_sin_to_cos (`bool`, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        time_embed_dim (`int`, defaults to `512`):
+            Output dimension of timestep embeddings.
+        ofs_embed_dim (`int`, *optional*, defaults to `None`):
+            Output dimension of "ofs" embeddings, as used by CogVideoX-5b-I2V in version 1.5.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        num_layers (`int`, defaults to `8`):
+            The number of layers of Transformer blocks to use.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        attention_bias (`bool`, defaults to `True`):
+            Whether to use bias in the attention projection layers.
+        sample_points (`int`, defaults to `2048`):
+            The number of points per frame in the input sequence.
+        sample_frames (`int`, defaults to `48`):
+            The number of frames in the input sequence.
+        patch_size (`int`, defaults to `1`):
+            The size of the patches to use in the patch embedding layer.
+        temporal_compression_ratio (`int`, defaults to `4`):
+            The compression ratio across the temporal dimension.
+        max_text_seq_length (`int`, defaults to `226`):
+            The maximum sequence length of the input text embeddings.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        timestep_activation_fn (`str`, defaults to `"silu"`):
+            Activation function to use when generating the timestep embeddings.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use elementwise affine in normalization layers.
+        norm_eps (`float`, defaults to `1e-5`):
+            The epsilon value to use in normalization layers.
+        spatial_interpolation_scale (`float`, defaults to `1.875`):
+            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+        temporal_interpolation_scale (`float`, defaults to `1.0`):
+            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+        cond_seq_length (`int`, defaults to `4`):
+            The number of condition tokens (force, E, nu, drag point) prepended to the point-token sequence.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 8,
+ attention_head_dim: int = 64,
+ in_channels: int = 3,
+ out_channels: Optional[int] = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ ofs_embed_dim: Optional[int] = None,
+ text_embed_dim: int = 4096,
+ num_layers: int = 8,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_points: int = 2048,
+ sample_frames: int = 48,
+ patch_size: int = 1,
+ patch_size_t: Optional[int] = None,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_positional_embeddings: bool = True,
+ use_learned_positional_embeddings: bool = False,
+ patch_bias: bool = True,
+ cond_seq_length: int = 4
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if use_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ self.embedding_dropout = nn.Dropout(dropout)
+
+        # 2. Time embeddings and ofs embedding (only CogVideoX 1.5-5B I2V has the latter)
+
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ self.ofs_proj = None
+ self.ofs_embedding = None
+ if ofs_embed_dim:
+ self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+ self.ofs_embedding = TimestepEmbedding(
+ ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+ ) # same as time embeddings, for ofs
+
+ # 3. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+
+ if patch_size_t is None:
+ # For CogVideox 1.0
+ output_dim = patch_size * patch_size * out_channels
+ else:
+ # For CogVideoX 1.5
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+ self.proj_out = nn.Linear(inner_dim, output_dim)
+
+ self.gradient_checkpointing = False
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ self.embed_dim = num_attention_heads * attention_head_dim
+ self.cond_seq_length = cond_seq_length
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_points, sample_frames)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+ def _get_positional_embeddings(self, points: int, frames: int, device: Optional[torch.device] = None) -> torch.Tensor:
+ pos_embedding = get_3d_sincos_pos_embed(
+ self.embed_dim,
+ points,
+ frames,
+ device=device,
+ output_type="pt",
+ )
+ pos_embedding = pos_embedding.flatten(0, 1)
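+        # Leave the first `cond_seq_length` positions at zero so the condition tokens carry no
+        # positional bias; only the frame x point tokens receive the 3D sinusoidal embedding.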
+ joint_pos_embedding = pos_embedding.new_zeros(
+ 1, self.cond_seq_length + points * frames, self.embed_dim, requires_grad=False
+ )
+ joint_pos_embedding.data[:, self.cond_seq_length:].copy_(pos_embedding)
+ return joint_pos_embedding
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+        full_seq: torch.Tensor,  # [batch_size, cond_seq_length + n_frames * n_points, inner_dim]
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ ofs: Optional[Union[int, float, torch.LongTensor]] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=full_seq.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ if self.ofs_embedding is not None:
+ ofs_emb = self.ofs_proj(ofs)
+ ofs_emb = ofs_emb.to(dtype=full_seq.dtype)
+ ofs_emb = self.ofs_embedding(ofs_emb)
+ emb = emb + ofs_emb
+
+        # 2. Positional embedding
+ pos_embedding = self.pos_embedding
+ pos_embedding = pos_embedding.to(dtype=full_seq.dtype)
+ hidden_states = full_seq + pos_embedding
+
+ hidden_states = self.embedding_dropout(hidden_states)
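+        # The first `cond_seq_length` tokens are the condition embeddings (force, E, nu, drag point in
+        # MDM_DiT); they act as `encoder_hidden_states`, the remaining tokens are frame x point tokens.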
+ encoder_hidden_states = hidden_states[:, :self.cond_seq_length]
+ hidden_states = hidden_states[:, self.cond_seq_length:]
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 4. Final block
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ return output
\ No newline at end of file
diff --git a/src/model/mdm_dit.py b/src/model/mdm_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ec21075cda19fc622342bd77f06e9e6560a425
--- /dev/null
+++ b/src/model/mdm_dit.py
@@ -0,0 +1,127 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import sys
+sys.path.append('./')
+
+from model.dit import CogVideoXTransformer3DModel
+
+class PointEmbed(nn.Module):
+ def __init__(self, hidden_dim=96, dim=512):
+ super().__init__()
+
+ assert hidden_dim % 6 == 0
+
+ self.embedding_dim = hidden_dim
+ e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
+ e = torch.stack([
+ torch.cat([e, torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6), e,
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6), e]),
+ ])
+        self.register_buffer('basis', e)  # (3, hidden_dim // 2), i.e. (3, 48) for the default hidden_dim=96
+
+ self.mlp = nn.Linear(self.embedding_dim+3, dim)
+
+ @staticmethod
+ def embed(input, basis):
+ projections = torch.einsum(
+ 'bnd,de->bne', input, basis)
+ embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
+ return embeddings
+
+ def forward(self, input):
+ # input: B x N x 3
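+        # Maps (B, N, 3) points to (B, N, dim): hidden_dim Fourier features (sin/cos of three
+        # axis-aligned frequency banks) concatenated with the raw xyz, then a single linear projection.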
+ embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
+ return embed
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+
+ self.register_parameter('pe', nn.Parameter(pe, requires_grad=False))
+
+ def forward(self, x):
+ # not used in the final model
+ x = x + self.pe[:x.shape[0], :]
+ return self.dropout(x)
+
+class MDM_DiT(nn.Module):
+
+ def __init__(self, n_points, n_frame, n_feats, model_config):
+ super().__init__()
+
+ self.n_points = n_points
+ self.n_feats = n_feats
+ self.latent_dim = model_config.latent_dim
+ self.cond_seq_length = 4
+ self.cond_frame = 1 if model_config.frame_cond else 0
+
+ self.dit = CogVideoXTransformer3DModel(sample_points=n_points, sample_frames=n_frame+self.cond_frame, in_channels=n_feats,
+ num_layers=model_config.n_layers, num_attention_heads=self.latent_dim // 64, cond_seq_length=self.cond_seq_length)
+
+ self.input_encoder = PointEmbed(dim=self.latent_dim)
+ # self.init_cond_encoder = PointEmbed(dim=self.latent_dim)
+ self.E_cond_encoder = nn.Linear(1, self.latent_dim)
+ self.nu_cond_encoder = nn.Linear(1, self.latent_dim)
+ self.force_cond_encoder = nn.Linear(3, self.latent_dim)
+ self.drag_point_encoder = nn.Linear(3, self.latent_dim)
+
+ def enable_gradient_checkpointing(self):
+        # Pass value=True explicitly; the first positional argument of _set_gradient_checkpointing is `module`.
+        self.dit._set_gradient_checkpointing(self.dit, value=True)
+
+ def forward(self, x, timesteps, init_pc, force, E, nu, drag_mask, drag_point, floor_height=None, coeff=None, y=None, null_emb=0):
+
+ """
+ x: [batch_size, frame, n_points, n_feats], denoted x_t in the paper
+ timesteps: [batch_size] (int)
+ """
+
+ bs, n_frame, n_points, n_feats = x.shape
+
+ init_pc = init_pc.reshape(bs, n_points, n_feats)
+ force = force.unsqueeze(1)
+ E = E.unsqueeze(1)
+ nu = nu.unsqueeze(1)
+ drag_point = drag_point.unsqueeze(1)
+ x = torch.cat([init_pc.unsqueeze(1), x], axis=1)
+ n_frame += 1
+ encoder_hidden_states = torch.cat([self.force_cond_encoder(force), self.E_cond_encoder(E),
+ self.nu_cond_encoder(nu), self.drag_point_encoder(drag_point)], axis=1)
+ hidden_states = self.input_encoder(x.reshape(bs * n_frame, n_points,
+ n_feats)).reshape(bs, n_frame, n_points, self.latent_dim)
+ full_seq = torch.cat([encoder_hidden_states, hidden_states.reshape(bs, n_frame * n_points, self.latent_dim)], axis=1)
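+        # full_seq layout: [4 condition tokens | (init frame + noisy frames) x n_points point tokens],
+        # matching cond_seq_length=4 and the positional embedding built inside the DiT.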
+ output = self.dit(full_seq, timesteps).reshape(bs, n_frame, n_points, 3)[:, self.cond_frame:]
+ output = output + init_pc.unsqueeze(1)
+
+ return output
+
+if __name__ == "__main__":
+
+    # Minimal smoke test; the model_config values below are placeholders chosen only so the shapes
+    # line up (latent_dim must be a multiple of 64 for the attention head split). Runs in float32
+    # so it also works on CPU.
+    from types import SimpleNamespace
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    point_num = 512
+    frame_num = 6
+
+    x = torch.randn(2, frame_num, point_num, 3).to(device)
+    timesteps = torch.tensor([999, 999]).long().to(device)
+    init_pc = torch.randn(2, 1, point_num, 3).to(device)
+    force = torch.randn(2, 3).to(device)
+    E = torch.randn(2, 1).to(device)
+    nu = torch.randn(2, 1).to(device)
+    drag_mask = torch.ones(2, point_num, 1).to(device)
+    drag_point = torch.randn(2, 3).to(device)
+
+    model_config = SimpleNamespace(latent_dim=512, n_layers=2, frame_cond=True)
+    model = MDM_DiT(point_num, frame_num, 3, model_config).to(device)
+    output = model(x, timesteps, init_pc, force, E, nu, drag_mask, drag_point)
+    print(output.shape)  # expected: (2, frame_num, point_num, 3)
diff --git a/src/model/spacetime.py b/src/model/spacetime.py
new file mode 100644
index 0000000000000000000000000000000000000000..66815f5d38f267a66d4658080e8c718f9dc48f30
--- /dev/null
+++ b/src/model/spacetime.py
@@ -0,0 +1,1113 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import sys
+sys.path.append('./')
+
+from einops import rearrange, repeat
+from model.dit import *
+from diffusers.models.embeddings import LabelEmbedding
+
+class PointEmbed(nn.Module):
+ def __init__(self, hidden_dim=96, dim=512):
+ super().__init__()
+
+ assert hidden_dim % 6 == 0
+
+ self.embedding_dim = hidden_dim
+ e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
+ e = torch.stack([
+ torch.cat([e, torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6), e,
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6), e]),
+ ])
+        self.register_buffer('basis', e)  # (3, hidden_dim // 2), i.e. (3, 48) for the default hidden_dim=96
+
+ self.mlp = nn.Linear(self.embedding_dim+3, dim)
+
+ @staticmethod
+ def embed(input, basis):
+ projections = torch.einsum(
+ 'bnd,de->bne', input, basis)
+ embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
+ return embeddings
+
+ def forward(self, input):
+ # input: B x N x 3
+ embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
+ return embed
+
+class AdaLayerNorm(nn.Module):
+ r"""
+ Norm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
+        output_dim (`int`, *optional*): The output dimension of the modulation; defaults to `2 * embedding_dim`.
+        norm_elementwise_affine (`bool`, defaults to `False`): Whether to use learnable elementwise affine parameters.
+        norm_eps (`float`, defaults to `1e-5`): Epsilon value for the normalization layer.
+        chunk_dim (`int`, defaults to `0`): The dimension along which `temb` is chunked into scale and shift.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_embeddings: Optional[int] = None,
+ output_dim: Optional[int] = None,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-5,
+ chunk_dim: int = 0,
+ ):
+ super().__init__()
+
+ self.chunk_dim = chunk_dim
+ output_dim = output_dim or embedding_dim * 2
+
+ if num_embeddings is not None:
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, output_dim)
+ self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
+
+ def forward(
+ self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if self.emb is not None:
+ temb = self.emb(timestep)
+
+ temb = self.linear(self.silu(temb))
+
+ if self.chunk_dim == 1:
+ # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
+ # other if-branch. This branch is specific to CogVideoX for now.
+ shift, scale = temb.chunk(2, dim=1)
+ shift = shift[:, None, :]
+ scale = scale[:, None, :]
+ else:
+ scale, shift = temb.chunk(2, dim=0)
+
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+@maybe_allow_in_graph
+class SpatialTemporalTransformerBlock(nn.Module):
+ r"""
+    Spatio-temporal transformer block, adapted from the [CogVideoX](https://github.com/THUDM/CogVideo) block
+    with an additional temporal attention over frames.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ self.norm_temp = AdaLayerNorm(dim, chunk_dim=1)
+ self.attn_temp = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ temb_in = temb
+ text_seq_length = encoder_hidden_states.size(1)
+
+ B, F, N, C = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, N, C)
+ if encoder_hidden_states.shape[0] != B * F:
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(F, 0)
+ temb = temb_in.repeat_interleave(F, 0)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # Spatial Attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ ## Time Attention
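+        # Fold the point dimension into the batch so this attention runs over the F frame tokens of
+        # each point (the spatial attention above attends over the N points within each frame).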
+ hidden_states = rearrange(hidden_states, '(b f) n c -> (b n) f c', f=F)
+ temb = temb_in.repeat_interleave(N, 0)
+ norm_hidden_states = self.norm_temp(hidden_states, temb=temb)
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, '(b n) f c -> b f n c', n=N)
+
+ # hidden_states = rearrange(hidden_states, '(b f) n c -> b f n c', f=F)
+
+ return hidden_states, encoder_hidden_states
+
+@maybe_allow_in_graph
+class SpatialOnlyTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ temb_in = temb
+ text_seq_length = encoder_hidden_states.size(1)
+
+ B, F, N, C = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, N, C)
+ if encoder_hidden_states.shape[0] != B * F:
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(F, 0)
+ temb = temb_in.repeat_interleave(F, 0)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # Spatial Attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ hidden_states = rearrange(hidden_states, '(b f) n c -> b f n c', f=F)
+
+ return hidden_states, encoder_hidden_states
+
+@maybe_allow_in_graph
+class TemporalOnlyTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ # self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ # self.attn1 = Attention(
+ # query_dim=dim,
+ # dim_head=attention_head_dim,
+ # heads=num_attention_heads,
+ # qk_norm="layer_norm" if qk_norm else None,
+ # eps=1e-6,
+ # bias=attention_bias,
+ # out_bias=attention_out_bias,
+ # processor=CogVideoXAttnProcessor2_0(),
+ # )
+
+ self.norm_temp = AdaLayerNorm(dim, chunk_dim=1)
+ self.attn_temp = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ temb_in = temb
+ text_seq_length = encoder_hidden_states.size(1)
+
+ B, F, N, C = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, N, C)
+ if encoder_hidden_states.shape[0] != B * F:
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(F, 0)
+ temb = temb_in.repeat_interleave(F, 0)
+
+ # # norm & modulate
+ # norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ # hidden_states, encoder_hidden_states, temb
+ # )
+
+ # # Spatial Attention
+ # attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ # hidden_states=norm_hidden_states,
+ # encoder_hidden_states=norm_encoder_hidden_states,
+ # image_rotary_emb=image_rotary_emb,
+ # )
+
+ # hidden_states = hidden_states + gate_msa * attn_hidden_states
+ # encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ ## Time Attention
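+        # Reshape so attention runs along the frame axis: each of the N points becomes its own
+        # "batch" entry of length F, i.e. every point attends over its trajectory through time
+        # rather than over the other points.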
+ hidden_states = rearrange(hidden_states, '(b f) n c -> (b n) f c', f=F)
+ temb = temb_in.repeat_interleave(N, 0)
+ norm_hidden_states = self.norm_temp(hidden_states, temb=temb)
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, '(b n) f c -> b f n c', n=N)
+
+ # hidden_states = rearrange(hidden_states, '(b f) n c -> b f n c', f=F)
+
+ return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class SpatialTemporalTransformerBlockv2(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ self.norm_temp = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+ self.attn_temp = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_time: torch.Tensor,
+ temb: torch.Tensor,
+ indices: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ temb_in = temb
+ text_seq_length = encoder_hidden_states.size(1)
+
+ B, F, N, C = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, N, C)
+ if encoder_hidden_states.shape[0] != B * F:
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(F, 0)
+ temb = temb_in.repeat_interleave(F, 0)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # Spatial Attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ indices=indices,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ ## Time Attention
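+        # Temporal attention with per-point condition tokens: points are moved to the batch axis so each
+        # point attends over its own frame sequence, jointly with `encoder_hidden_states_time` (the tokens
+        # split off from the first `cond_seq_length_t` conditioning frame(s) in the transformer's forward).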
+ hidden_states = rearrange(hidden_states, '(b f) n c -> (b n) f c', f=F)
+ temb = temb_in.repeat_interleave(N, 0)
+ norm_hidden_states, norm_encoder_hidden_states_time, gate_msa, enc_gate_msa = self.norm_temp(
+ hidden_states, encoder_hidden_states_time, temb
+ )
+ attn_hidden_states, attn_encoder_hidden_states_time = self.attn_temp(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states_time
+ )
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states_time = encoder_hidden_states_time + enc_gate_msa * attn_encoder_hidden_states_time
+ hidden_states = rearrange(hidden_states, '(b n) f c -> b f n c', n=N)
+
+ return hidden_states, encoder_hidden_states, encoder_hidden_states_time
+
+class SpaitalTemporalTransformer(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+    A spatio-temporal Transformer for 3D point-trajectory data, adapted from the
+    [CogVideoX](https://github.com/THUDM/CogVideo) video transformer. Part of the parameter
+    documentation below is inherited from the original model and may not match the defaults used here.
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ ofs_embed_dim (`int`, defaults to `512`):
+            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2V in version 1.5.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
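+    # Example instantiation (illustrative values only; this mirrors how `MDM_ST` below wires up the
+    # backbone, see `MDM_ST.__init__` for the actual arguments):
+    #   dit = SpaitalTemporalTransformer(
+    #       sample_points=2048, sample_frames=n_frame + cond_frame, in_channels=3,
+    #       num_layers=8, num_attention_heads=latent_dim // 64, time_embed_dim=latent_dim,
+    #       cond_seq_length=cond_seq_length, cond_seq_length_t=cond_frame,
+    #       transformer_block="SpatialTemporalTransformerBlockv2",
+    #   )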
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 8,
+ attention_head_dim: int = 64,
+ in_channels: int = 3,
+ out_channels: Optional[int] = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ ofs_embed_dim: Optional[int] = None,
+ text_embed_dim: int = 4096,
+ num_layers: int = 8,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_points: int = 2048,
+ sample_frames: int = 48,
+ patch_size: int = 1,
+ patch_size_t: Optional[int] = None,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_positional_embeddings: bool = True,
+ use_learned_positional_embeddings: bool = False,
+ patch_bias: bool = True,
+ cond_seq_length: int = 4,
+ cond_seq_length_t: int = 2,
+ transformer_block: str = "SpatialTemporalTransformerBlock",
+ num_classes: int = 0,
+ class_dropout_prob: float = 0.0,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if use_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+                "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned "
+                "positional embeddings enabled. If you're using a custom model and/or believe this should be "
+                "supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ self.embedding_dropout = nn.Dropout(dropout)
+
+        # 2. Time embeddings and ofs embedding (only CogVideoX 1.5-5B I2V has it)
+
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ self.ofs_proj = None
+ self.ofs_embedding = None
+ if ofs_embed_dim:
+ self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+ self.ofs_embedding = TimestepEmbedding(
+ ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+ ) # same as time embeddings, for ofs
+
+ self.class_embedder = None
+ if num_classes > 0:
+ self.class_embedder = LabelEmbedding(num_classes, time_embed_dim, class_dropout_prob)
+
+ self.transformer_block = transformer_block
+ if transformer_block == "SpatialTemporalTransformerBlock":
+ TransformerBlock = SpatialTemporalTransformerBlock
+ elif transformer_block == "SpatialTemporalTransformerBlockv2":
+ TransformerBlock = SpatialTemporalTransformerBlockv2
+ elif transformer_block == "SpatialTemporalTransformerBlockv3":
+ TransformerBlock = SpatialTemporalTransformerBlockv3
+ elif transformer_block == "SpatialOnlyTransformerBlock":
+ TransformerBlock = SpatialOnlyTransformerBlock
+ elif transformer_block == "TemporalOnlyTransformerBlock":
+ TransformerBlock = TemporalOnlyTransformerBlock
+
+ # 3. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+
+ if patch_size_t is None:
+ # For CogVideox 1.0
+ output_dim = patch_size * patch_size * out_channels
+ else:
+ # For CogVideoX 1.5
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+ self.proj_out = nn.Linear(inner_dim, output_dim)
+
+ self.gradient_checkpointing = False
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ self.embed_dim = num_attention_heads * attention_head_dim
+ self.cond_seq_length = cond_seq_length
+ self.cond_seq_length_t = cond_seq_length_t
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_points, sample_frames)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
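+    # Positional-embedding layout (a reading aid, not new behaviour): the buffer has shape
+    # [1, cond_seq_length + points * frames, embed_dim]. The first `cond_seq_length` slots stay zero,
+    # so condition tokens carry no positional signal; the remaining slots hold a 3D sin/cos embedding
+    # flattened over (frames, points).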
+ def _get_positional_embeddings(self, points: int, frames: int, device: Optional[torch.device] = None) -> torch.Tensor:
+ pos_embedding = get_3d_sincos_pos_embed(
+ self.embed_dim,
+ points,
+ frames,
+ device=device,
+ output_type="pt",
+ )
+ pos_embedding = pos_embedding.flatten(0, 1)
+ joint_pos_embedding = pos_embedding.new_zeros(
+ 1, self.cond_seq_length + points * frames, self.embed_dim, requires_grad=False
+ )
+ joint_pos_embedding.data[:, self.cond_seq_length:].copy_(pos_embedding)
+ return joint_pos_embedding
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the
+            model, indexed by weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def forward(
+ self,
+        hidden_states: torch.Tensor,  # [batch_size, frames, points, channels]
+        encoder_hidden_states: torch.Tensor,  # [batch_size, cond_seq_length, channels]
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.Tensor] = None,
+ force_drop_ids: Optional[torch.Tensor] = None,
+ ofs: Optional[Union[int, float, torch.LongTensor]] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ indices: Optional[torch.LongTensor] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # TODO: check force drop id shape
+ if self.class_embedder is not None:
+ assert class_labels is not None
+ class_labels = self.class_embedder(class_labels, force_drop_ids=force_drop_ids) # (N, D)
+ emb = emb + class_labels
+
+ if self.ofs_embedding is not None:
+ ofs_emb = self.ofs_proj(ofs)
+ ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+ ofs_emb = self.ofs_embedding(ofs_emb)
+ emb = emb + ofs_emb
+
+ B, F, N, C = hidden_states.shape
+ full_seq = torch.cat([encoder_hidden_states, hidden_states.reshape(B, F*N, -1)], axis=1)
+
+ # 2. Patch embedding
+ pos_embedding = self.pos_embedding
+ pos_embedding = pos_embedding.to(dtype=full_seq.dtype)
+ hidden_states = full_seq + pos_embedding
+
+ hidden_states = self.embedding_dropout(hidden_states)
+ encoder_hidden_states = hidden_states[:, :self.cond_seq_length]
+ hidden_states = hidden_states[:, self.cond_seq_length:].reshape(B, F, N, C)
+
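+        # Token layout at this point: `encoder_hidden_states` holds the first `cond_seq_length` condition
+        # tokens, `hidden_states` the F x N point tokens. For the v2/v3 blocks, the first
+        # `cond_seq_length_t` frames are additionally split off as per-point temporal condition tokens.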
+ if self.transformer_block not in ["SpatialTemporalTransformerBlock", 'TemporalOnlyTransformerBlock', 'SpatialOnlyTransformerBlock']:
+ encoder_hidden_states_time = hidden_states[:, :self.cond_seq_length_t]
+ encoder_hidden_states_time = rearrange(encoder_hidden_states_time, 'b f n c -> (b n) f c')
+ hidden_states = hidden_states[:, self.cond_seq_length_t:]
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if self.transformer_block in ["SpatialTemporalTransformerBlock", 'TemporalOnlyTransformerBlock', 'SpatialOnlyTransformerBlock']:
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states, encoder_hidden_states_time = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ encoder_hidden_states_time,
+ emb,
+ image_rotary_emb,
+ indices=indices,
+ **ckpt_kwargs,
+ )
+ else:
+ if self.transformer_block in ["SpatialTemporalTransformerBlock", 'TemporalOnlyTransformerBlock', 'SpatialOnlyTransformerBlock']:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ else:
+ hidden_states, encoder_hidden_states, encoder_hidden_states_time = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_time=encoder_hidden_states_time,
+ temb=emb,
+ indices=indices,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = rearrange(hidden_states, 'b f n c -> b (f n) c')
+ # 4. Final block
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ return output
+
+class MDM_ST(nn.Module):
+ def __init__(self, n_points, n_frame, n_feats, model_config):
+ super().__init__()
+ print('use new model')
+
+ self.n_points = n_points
+ self.n_feats = n_feats
+ self.latent_dim = model_config.latent_dim
+ self.cond_frame = 1 if model_config.frame_cond else 0
+ self.frame_cond = model_config.frame_cond
+
+ if model_config.get('point_embed', True):
+ self.input_encoder = PointEmbed(dim=self.latent_dim)
+ else:
+ print('not using point embedding')
+ self.input_encoder = nn.Linear(n_feats, self.latent_dim)
+ self.mask_cond = model_config.get('mask_cond', False)
+ if self.mask_cond:
+ print('Use mask condition')
+ self.mask_encoder = nn.Linear(1, self.latent_dim)
+ self.cond_frame += 1
+ self.pred_offset = model_config.get('pred_offset', True)
+ self.num_neighbors = model_config.get('num_neighbors', 0)
+ self.max_num_forces = model_config.get('max_num_forces', 1)
+ self.model_config = model_config
+
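+        # Condition-token bookkeeping: start with two tokens (E and nu); force / drag-point, gravity,
+        # floor, drag-coefficient and material-class tokens are appended below when enabled, and
+        # `cond_seq_length` grows accordingly so the transformer reserves the right number of slots.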
+ self.cond_seq_length = 2
+
+ self.E_cond_encoder = nn.Linear(1, self.latent_dim)
+ self.nu_cond_encoder = nn.Linear(1, self.latent_dim)
+ self.force_as_token = model_config.get('force_as_token', True)
+ self.force_as_latent = model_config.get('force_as_latent', False)
+
+ if self.force_as_latent:
+ self.input_encoder = nn.Linear(n_feats + 4 * self.max_num_forces, self.latent_dim)
+ elif self.force_as_token:
+ self.cond_seq_length += self.max_num_forces * 2
+ self.force_cond_encoder = nn.Linear(3, self.latent_dim)
+ self.drag_point_encoder = nn.Linear(3, self.latent_dim)
+ else:
+ self.cond_seq_length += 2
+ self.force_cond_encoder = nn.Linear(3, self.latent_dim)
+ self.drag_point_encoder = nn.Linear(3, self.latent_dim)
+
+ self.gravity_emb = model_config.get('gravity_emb', False)
+ if self.gravity_emb:
+ self.gravity_embedding = nn.Embedding(2, self.latent_dim)
+ self.cond_seq_length += 1
+
+ if self.model_config.floor_cond:
+ self.floor_encoder = nn.Linear(1, self.latent_dim)
+ self.cond_seq_length += 1
+
+ if self.model_config.coeff_cond:
+ self.coeff_encoder = nn.Linear(1, self.latent_dim)
+ self.cond_seq_length += 1
+
+ self.num_mat = model_config.get('num_mat', 0)
+ if model_config.class_token:
+ self.class_embedding = nn.Embedding(model_config.num_mat, self.latent_dim)
+ self.cond_seq_length += 1
+
+ self.class_dropout_prob = model_config.get('class_dropout_prob', 0.0)
+        self.dit = SpaitalTemporalTransformer(
+            sample_points=n_points,
+            sample_frames=n_frame + self.cond_frame,
+            in_channels=n_feats,
+            num_layers=model_config.n_layers,
+            num_attention_heads=self.latent_dim // 64,
+            time_embed_dim=self.latent_dim,
+            cond_seq_length=self.cond_seq_length,
+            cond_seq_length_t=self.cond_frame,
+            transformer_block=model_config.transformer_block,
+            num_classes=self.num_mat,
+            class_dropout_prob=self.class_dropout_prob,
+        )
+
+ self._init_weights()
+
+ def _init_weights(self):
+ if self.gravity_emb:
+ nn.init.normal_(self.gravity_embedding.weight, mean=0.0, std=0.1)
+
+ def enable_gradient_checkpointing(self):
+ self.dit._set_gradient_checkpointing(True)
+
+ def forward(self, x, timesteps, init_pc, force, E, nu, drag_mask, drag_point, floor_height, gravity_label=None, coeff=None, y=None, null_emb=None):
+
+ """
+ x: [batch_size, frame, n_points, n_feats], denoted x_t in the paper
+ timesteps: [batch_size] (int)
+ """
+
+ bs, n_frame, n_points, n_feats = x.shape
+
+ init_pc = init_pc.reshape(bs, n_points, n_feats)
+ force = force.unsqueeze(1) if force.ndim == 2 else force
+ drag_point = drag_point.unsqueeze(1) if drag_point.ndim == 2 else drag_point
+ E = E.unsqueeze(1)
+ nu = nu.unsqueeze(1)
+
+ if self.num_neighbors > 0:
+ rel_dist = torch.cdist(init_pc, init_pc)
+ dist, indices = rel_dist.topk(self.num_neighbors, largest = False)
+ indices = indices.repeat_interleave(n_frame, 0)
+ # indices = torch.cat([indices, torch.tensor([2048, 2049, 2050, 2051])[None, None].repeat(bs*n_frame, n_points, 1).to(indices.device)], axis=2)
+ else:
+ indices = None
+
+ if self.force_as_token:
+ force_emb = self.force_cond_encoder(force) + self.gravity_embedding(gravity_label) if self.gravity_emb else self.force_cond_encoder(force)
+ encoder_hidden_states = torch.cat([self.E_cond_encoder(E), self.nu_cond_encoder(nu)], axis=1)
+ # force_info = torch.cat([force, drag_point], dim=-1) # (B, n_forces, 7)
+ # force_tokens = self.force_cond_encoder(force_info)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, force_emb, self.drag_point_encoder(drag_point[..., :3])], axis=1)
+ elif self.force_as_latent:
+ encoder_hidden_states = torch.cat([self.E_cond_encoder(E), self.nu_cond_encoder(nu)], axis=1)
+ force = force.unsqueeze(1).repeat(1, n_points, 1, 1) # (B, n_points, n_forces, 3)
+ all_force = torch.cat([force, drag_mask.permute(0, 2, 1, 3)], dim=-1).reshape(bs, n_points, -1) # (B, n_points, n_forces, 4)
+ else:
+ encoder_hidden_states = torch.cat([self.force_cond_encoder(force), self.E_cond_encoder(E),
+ self.nu_cond_encoder(nu), self.drag_point_encoder(drag_point[..., :3])], axis=1)
+ if self.gravity_emb:
+ encoder_hidden_states = torch.cat([encoder_hidden_states, self.gravity_embedding(gravity_label)], axis=1)
+ if self.model_config.class_token:
+ class_labels = y.unsqueeze(1)
+ class_labels = self.class_embedding(class_labels)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, class_labels], axis=1)
+ if self.model_config.floor_cond:
+ floor_height = floor_height.unsqueeze(1) if floor_height is not None else None
+ encoder_hidden_states = torch.cat([encoder_hidden_states, self.floor_encoder(floor_height)], axis=1)
+ if self.model_config.coeff_cond:
+ coeff = coeff.unsqueeze(1) if coeff is not None else None
+ encoder_hidden_states = torch.cat([encoder_hidden_states, self.coeff_encoder(coeff)], axis=1)
+ if null_emb is not None:
+ encoder_hidden_states = encoder_hidden_states * null_emb
+ if self.frame_cond:
+ x = torch.cat([init_pc.unsqueeze(1), x], axis=1) # Condition on first frame
+ if self.force_as_latent:
+ all_force = all_force.unsqueeze(1).repeat(1, x.shape[1], 1, 1) # (B, n_frame, n_points, n_forces*4)
+ x = torch.cat([x, all_force], dim=-1) # (B, n_frame, n_points, n_feats+n_forces * 4)
+ n_feats = x.shape[-1]
+        hidden_states = self.input_encoder(
+            x.reshape(-1, n_points, n_feats)
+        ).reshape(bs, -1, n_points, self.latent_dim)
+ if self.mask_cond:
+ mask = self.mask_encoder(drag_mask[:, :1])
+ hidden_states = torch.cat([mask, hidden_states], axis=1)
+ if self.model_config.transformer_block in ["SpatialTemporalTransformerBlock", "TemporalOnlyTransformerBlock", "SpatialOnlyTransformerBlock"]:
+ output = self.dit(hidden_states, encoder_hidden_states, timesteps, class_labels=y).reshape(bs, -1, n_points, 3)[:, self.cond_frame:]
+ else:
+ output = self.dit(hidden_states, encoder_hidden_states, timesteps, indices=indices).reshape(bs, -1, n_points, 3)
+ output = output + init_pc.unsqueeze(1) if self.pred_offset else output
+
+ return output
+
+if __name__ == "__main__":
+
+ # Diffusion
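+    # Quick smoke test / speed benchmark with random dummy inputs; it assumes the referenced
+    # eval.yaml config path exists and is not part of training or inference.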
+ from omegaconf import OmegaConf
+ from options import TestingConfig
+ cfg_path = '../traj-diff/configs/eval.yaml'
+ config_path = 'model_config.yaml'
+ device = 'cuda'
+ schema = OmegaConf.structured(TestingConfig)
+ cfg = OmegaConf.load(cfg_path)
+ cfg = OmegaConf.merge(schema, cfg)
+ n_training_frames = cfg.train_dataset.n_training_frames
+ n_frames_interval = cfg.train_dataset.n_frames_interval
+
+ point_num = 2048
+ frame_num = 24
+ x = torch.randn(1, frame_num, point_num, 3).to(device).to(torch.float16)
+ timesteps = torch.tensor([999]).int().to(device).to(torch.float16)
+ init_pc = torch.randn(1, 1, point_num, 3).to(device).to(torch.float16)
+ force = torch.randn(1, 3).to(device).to(torch.float16)
+ E = torch.randn(1, 1).to(device).to(torch.float16)
+ nu = torch.randn(1, 1).to(device).to(torch.float16)
+
+ x = nn.Parameter(x)
+
+ with torch.enable_grad():
+ # with torch.no_grad():
+ t_total = 0
+ for i in range(100):
+ model = MDM_ST(point_num, frame_num, 3, cfg.model_config).to(device).to(torch.float16)
+ model.train()
+ import time
+ t0 = time.time()
+ output = model(x, timesteps, init_pc, force, E, nu, None, force, torch.zeros_like(E), torch.ones_like(E), None)
+ loss = output.sum()
+ loss.backward()
+ t1 = time.time()
+            if i >= 10:  # skip the first 10 iterations as warm-up so 90 timed runs are averaged below
+ t_total += t1 - t0
+ print(t1 - t0)
+
+ print("Average time: ", t_total / 90)
\ No newline at end of file
diff --git a/src/options.py b/src/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..4455d848605570a4f65705ee1d28823e8fd47c71
--- /dev/null
+++ b/src/options.py
@@ -0,0 +1,87 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, List
+
+@dataclass
+class TrainingConfig:
+ image_size: int
+ # train_batch_size = 16
+ # eval_batch_size = 16 # how many images to sample during evaluation
+ # num_epochs = 50
+ # gradient_accumulation_steps = 1
+ # learning_rate = 1e-4
+ # lr_warmup_steps = 500
+ # save_image_epochs = 10
+ # save_model_epochs = 30
+ # mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
+ # output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub
+ # logging
+ output_dir: str
+ logging_dir: str
+ vis_dir: str
+ report_to: Optional[str]
+ local_rank: int
+ tracker_project_name: str
+
+ # Training
+ seed: Optional[int]
+ train_batch_size: int
+ eval_batch_size: int
+ num_train_epochs: int
+ max_train_steps: int
+ gradient_accumulation_steps: int
+ gradient_checkpointing: bool
+ learning_rate: float
+ scale_lr: bool
+ lr_scheduler: str
+ lr_warmup_steps: int
+ use_8bit_adam: bool
+ allow_tf32: bool
+ dataloader_num_workers: int
+ adam_beta1: float
+ adam_beta2: float
+ adam_weight_decay: float
+ adam_epsilon: float
+ max_grad_norm: Optional[float]
+ prediction_type: Optional[str]
+ mixed_precision: Optional[str]
+ checkpointing_steps: int
+ checkpoints_total_limit: Optional[int]
+ resume_from_checkpoint: Optional[str]
+ enable_xformers_memory_efficient_attention: bool
+ validation_steps: int
+ validation_train_steps: int
+ validation_sanity_check: bool
+ resume_step: Optional[int]
+ push_to_hub: bool
+ set_grads_to_none: bool
+ lambda_vel: float
+    lambda_mask: float
+ lambda_momentum: float
+ lambda_deform: float
+ overfit: bool
+
+ # Diffusion Specific
+ condition_drop_rate: float
+
+ # Dataset
+ train_dataset: Dict
+
+ # Model
+ model_type: str
+ pred_offset: bool
+ model_config: Dict
+ pc_size: int
+
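+# Configs are plain OmegaConf structured schemas: train.py builds them with
+#   schema = OmegaConf.structured(TrainingConfig)
+#   cfg = OmegaConf.merge(schema, OmegaConf.load(path_to_yaml))
+# so every field above must be present in (or resolvable from) the YAML file.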
+@dataclass
+class TestingConfig:
+ dataloader_num_workers: int
+ pc_size: int
+ model_type: str
+ pred_offset: bool
+ model_config: Dict
+ train_dataset: Dict
+ resume: str
+ vis_dir: str
+ eval_batch_size: int
+ seed: int
+ num_inference_steps: int
\ No newline at end of file
diff --git a/src/pipeline_traj.py b/src/pipeline_traj.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f6bf5f37ce1e35c35eeb80e6f7493347f82754
--- /dev/null
+++ b/src/pipeline_traj.py
@@ -0,0 +1,60 @@
+import torch
+from diffusers import DiffusionPipeline
+
+
+class TrajPipeline(DiffusionPipeline):
+ def __init__(self, model, scheduler):
+ super().__init__()
+ self.register_modules(model=model, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(self, init_pc, force, E, nu, mask, drag_point, floor_height, gravity, coeff,
+ generator,
+ device,
+ y = None,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ guidance_scale=1.0,
+ n_frames=20
+ ):
+ # Sample gaussian noise to begin loop
+ sample = torch.randn((batch_size, n_frames, init_pc.shape[2], 3), generator=generator).to(device)
+ self.model.to(device)
+ init_pc = init_pc.to(device)
+ force = force.to(device)
+ E = E.to(device)
+ nu = nu.to(device)
+ mask = mask.to(device).to(dtype=sample.dtype)
+ drag_point = drag_point.to(device)
+ floor_height = floor_height.to(device)
+ coeff = coeff.to(device)
+ gravity = gravity.to(device) if gravity is not None else None
+ y = y.to(device) if y is not None else None
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ do_classifier_free_guidance = (guidance_scale > 1.0)
+ null_emb = torch.tensor([1] * batch_size).to(sample.dtype)
+        if do_classifier_free_guidance:
+            # Duplicate every condition along the batch dimension; the first half is the unconditional
+            # branch, whose condition tokens are zeroed inside the model via `null_emb`.
+            init_pc = torch.cat([init_pc] * 2)
+            force = torch.cat([force] * 2)
+            E = torch.cat([E] * 2)
+            nu = torch.cat([nu] * 2)
+            mask = torch.cat([mask] * 2)
+            drag_point = torch.cat([drag_point] * 2)
+            floor_height = torch.cat([floor_height] * 2)
+            coeff = torch.cat([coeff] * 2)
+            gravity = torch.cat([gravity] * 2) if gravity is not None else None
+            y = torch.cat([y] * 2) if y is not None else None
+            null_emb = torch.cat([torch.tensor([0] * batch_size).to(sample.dtype), null_emb])
+        null_emb = null_emb[:, None, None].to(device)
+ for t in self.progress_bar(self.scheduler.timesteps):
+ t = torch.tensor([t] * batch_size, device=device)
+ sample_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+ t = torch.cat([t] * 2) if do_classifier_free_guidance else t
+ # 1. predict noise model_output
+ model_output = self.model(sample_input, t, init_pc, force, E, nu, mask, drag_point, floor_height=floor_height, gravity_label=gravity, coeff=coeff, y=y, null_emb=null_emb)
+ if do_classifier_free_guidance:
+ model_pred_uncond, model_pred_cond = model_output.chunk(2)
+ model_output = model_pred_uncond + guidance_scale * (model_pred_cond - model_pred_uncond)
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ sample = self.scheduler.step(model_output, t[0], sample).prev_sample
+ return sample
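+
+# Minimal usage sketch (hypothetical variable names; assumes a trained MDM_ST model and a DDIMScheduler,
+# mirroring the validation loop in train.py):
+#   pipe = TrajPipeline(model=model, scheduler=DDIMScheduler.from_config(noise_scheduler.config))
+#   traj = pipe(points_src, force, E, nu, mask, drag_point, floor_height, gravity, coeff,
+#               generator=torch.Generator().manual_seed(0), device="cuda",
+#               num_inference_steps=50, guidance_scale=2.0, n_frames=20)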
\ No newline at end of file
diff --git a/src/templates/projection.npy b/src/templates/projection.npy
new file mode 100644
index 0000000000000000000000000000000000000000..325eb3c48000c39f0d8cb4ca5b64cda7352d5549
--- /dev/null
+++ b/src/templates/projection.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e236699b504df5652f551f41e5363a6eeb3778a15a433d5cfdfd2b6352d1d593
+size 192
diff --git a/src/templates/tracks_template.npy b/src/templates/tracks_template.npy
new file mode 100644
index 0000000000000000000000000000000000000000..813359159d2d075301a23983260de6f85afe2cd4
--- /dev/null
+++ b/src/templates/tracks_template.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9efc35800a3b22026393371891e5b893dead64a0facfa5f768c408fb97d86ac
+size 6002991
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..21b1c97c2a1a07cee5eb7b7e626194423794c274
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,447 @@
+import argparse
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+from omegaconf import OmegaConf
+from options import TrainingConfig
+
+import numpy as np
+import safetensors
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import HfApi, create_repo
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from torchvision import transforms
+
+import diffusers
+from diffusers import (
+ AutoencoderKL, DDPMScheduler, DDPMPipeline, DDIMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, UNet2DModel
+)
+from diffusers.loaders import AttnProcsLayers
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.optimization import get_cosine_schedule_with_warmup
+from pipeline_traj import TrajPipeline
+from accelerate.utils import DistributedDataParallelKwargs
+
+from model.spacetime import MDM_ST
+from dataset.traj_dataset import TrajDataset
+
+from utils.visualization import save_pointcloud_video, save_pointcloud_json, save_threejs_html
+from utils.physics import loss_momentum
+from utils.physics import DeformLoss
+
+logger = get_logger(__name__)
+
+def seed_everything(seed):
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+def main(args):
+ vis_dir = os.path.join(args.output_dir, args.vis_dir)
+ logging_dir = Path(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ # kwargs_handlers=[kwargs]
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = {}
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+ seed_everything(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+ os.makedirs(vis_dir, exist_ok=True)
+ OmegaConf.save(cfg, os.path.join(cfg.output_dir, 'config.yaml'))
+
+ src_snapshot_folder = os.path.join(cfg.output_dir, 'src')
+ ignore_func = lambda d, files: [f for f in files if f.endswith('__pycache__')]
+ for folder in ['model', 'dataset']:
+ dst_dir = os.path.join(src_snapshot_folder, folder)
+ shutil.copytree(folder, dst_dir, ignore=ignore_func, dirs_exist_ok=True)
+ shutil.copy(os.path.abspath(__file__), os.path.join(cfg.output_dir, 'src', 'train.py'))
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ model = MDM_ST(args.pc_size, args.train_dataset.n_training_frames, n_feats=3, model_config=args.model_config)
+
+ # if args.gradient_checkpointing:
+ # model.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ params = model.parameters()
+ # Optimizer creation
+ optimizer = optimizer_class(
+ [
+ {"params": params, "lr": args.learning_rate},
+ ],
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # if args.model_type == 'dit_st_water':
+ # from dataset.water_dataset import TrajDataset
+ # Dataset and DataLoaders creation:
+ train_dataset = TrajDataset('train', args.train_dataset)
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers, pin_memory=True)
+
+ val_dataset = TrajDataset('val', args.train_dataset)
+ val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.dataloader_num_workers)
+
+ # noise = torch.randn(sample_image.shape)
+ # timesteps = torch.LongTensor([50])
+ # noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)
+ # Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_cosine_schedule_with_warmup(
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ )
+
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+    logger.info(f"  Num of Trainable Parameters (M) = {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6}")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ logger.info(f" Log to = {args.output_dir}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ noise_scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type='sample', clip_sample=False)
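+    # With prediction_type='sample' the network is trained to output the clean (denoised) trajectory
+    # directly, which is why the main loss below compares `pred_sample` against the ground-truth
+    # `latents` rather than against the added noise.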
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ loss_deform = DeformLoss()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ model.train()
+ train_loss = 0.0
+ for step, (batch, _) in enumerate(train_dataloader):
+ with accelerator.accumulate(model):
+ latents = batch['points_tgt'] # (bsz, n_frames, n_points, 3)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+                # Sample a random timestep for each trajectory in the batch
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ if args.condition_drop_rate > 0:
+ # Randomly drop some of the latents
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
+ null_emb = (random_p > args.condition_drop_rate).float()[..., None, None]
+ else:
+ null_emb = None
+
+                # Predict the denoised trajectory (the scheduler is configured with prediction_type='sample')
+                pred_sample = model(
+                    noisy_latents, timesteps, batch['points_src'], batch['force'], batch['E'], batch['nu'],
+                    batch['mask'][..., :1], batch['drag_point'], batch['floor_height'], batch['gravity'],
+                    batch['base_drag_coeff'],
+                    y=None if 'mat_type' not in batch else batch['mat_type'],
+                    null_emb=null_emb,
+                )
+ losses = {}
+
+ loss = F.mse_loss(pred_sample.float(), latents.float())
+ losses['xyz'] = loss.detach().item()
+
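+                # Optional auxiliary losses, enabled by their lambda_* / config flags: a masked-point term,
+                # a velocity term, a momentum-conservation term (loss_momentum), a deformation-gradient term
+                # (DeformLoss) for MPM materials, and a floor non-penetration penalty.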
+ if args.lambda_mask > 0:
+ loss_mask = F.mse_loss(pred_sample[batch['mask']], latents[batch['mask']])
+ loss += loss_mask
+ losses['mask'] = loss_mask.detach().item()
+
+ if args.lambda_vel > 0.:
+ target_vel = latents[:, 1:] - latents[:, :-1]
+ pred_vel = (pred_sample[:, 1:] - pred_sample[:, :-1])
+ loss_vel = F.mse_loss(target_vel.float(), pred_vel.float())
+ losses['loss_vel'] = loss_vel.detach().item()
+ loss = loss + loss_vel
+
+ if 'vol' in batch and args.lambda_momentum > 0.:
+ loss_p = loss_momentum(x=pred_sample, vol=batch['vol'], force=batch['weighted_force'],
+ drag_pt_num=batch['mask'][:, 0, :].sum(dim=1), norm_fac=args.train_dataset.norm_fac)
+ losses['loss_p'] = loss_p.detach().item()
+ loss = loss + args.lambda_momentum * loss_p
+
+ if 'vol' in batch and args.lambda_deform > 0.:
+ pred_sample_mpm = pred_sample
+ if 'is_mpm' in batch:
+ mask = batch['is_mpm']
+ pred_sample_mpm = pred_sample[mask]
+ batch['vol'] = batch['vol'][mask]
+ batch['F'] = batch['F'][mask]
+ batch['C'] = batch['C'][mask]
+ loss_F = loss_deform(x=pred_sample_mpm.clamp(min=-2.2, max=2.2), vol=batch['vol'], F=batch['F'],
+ C=batch['C'], frame_interval=2, norm_fac=args.train_dataset.norm_fac) if batch['vol'].shape[0] > 0 else torch.tensor(0.0, device=pred_sample.device)
+ losses['loss_deform'] = loss_F.detach().item()
+ loss = loss + args.lambda_deform * loss_F
+
+ if args.model_config.floor_cond:
+ floor_height = batch['floor_height'].reshape(bsz, 1, 1) # (B, 1, 1)
+ sample_min_height = torch.amin(latents[..., 1], dim=(1, 2)).reshape(bsz, 1, 1)
+ floor_height = torch.minimum(floor_height, sample_min_height)
+ loss_floor = (torch.relu(floor_height - pred_sample[..., 1]) ** 2).mean()
+ losses['loss_floor'] = loss_floor.detach().item()
+ loss += loss_floor
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(cfg.train_batch_size)).mean()
+ train_loss += avg_loss.item() / cfg.gradient_accumulation_steps
+
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % cfg.validation_steps == 0 or global_step == 1:
+ if accelerator.is_main_process:
+ model.eval()
+ pipeline = TrajPipeline(model=accelerator.unwrap_model(model), scheduler=DDIMScheduler.from_config(noise_scheduler.config))
+ logger.info(
+ f"Running validation... \n."
+ )
+ for i, (batch, _) in enumerate(val_dataloader):
+ with torch.autocast("cuda"):
+ gs = [1.0] if args.condition_drop_rate == 0 else [1.0, 2.0, 3.0]
+ for guidance_scale in gs:
+ output = pipeline(batch['points_src'], batch['force'], batch['E'], batch['nu'], batch['mask'][..., :1], batch['drag_point'], batch['floor_height'], batch['gravity'], batch['base_drag_coeff'], y=None if 'mat_type' not in batch else batch['mat_type'], device=accelerator.device, batch_size=args.eval_batch_size, generator=torch.Generator().manual_seed(args.seed), n_frames=args.train_dataset.n_training_frames, guidance_scale=guidance_scale)
+ output = output.cpu().numpy()
+ tgt = batch['points_tgt'].cpu().numpy()
+ save_dir = os.path.join(vis_dir, f'{global_step:06d}')
+ os.makedirs(save_dir, exist_ok=True)
+ for j in range(output.shape[0]):
+ save_pointcloud_video(output[j:j+1].squeeze(), tgt[j:j+1].squeeze(), os.path.join(save_dir, f'{i*batch["points_src"].shape[0] + j}_{guidance_scale}.gif'),
+ drag_mask=batch['mask'][j:j+1, 0, :, 0].cpu().numpy().squeeze(), vis_flag=args.train_dataset.dataset_path)
+ # pred_name = f'{i*batch["points_src"].shape[0]+j}_pred.json'
+ # gt_name = f'{i*batch["points_src"].shape[0]+j}_gt.json'
+ # save_pointcloud_json(output[j:j+1].squeeze(), os.path.join(save_dir, pred_name))
+ # save_pointcloud_json(tgt[j:j+1].squeeze(), os.path.join(save_dir, gt_name))
+ # save_threejs_html(pred_name, gt_name, os.path.join(save_dir, f'{j}.html'))
+ torch.cuda.empty_cache()
+ model.train()
+
+ logs = losses
+ logs.update({"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]})
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the custom diffusion layers
+ accelerator.wait_for_everyone()
+ # if accelerator.is_main_process:
+ # unet = unet.to(torch.float32)
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True)
+ args = parser.parse_args()
+ schema = OmegaConf.structured(TrainingConfig)
+ cfg = OmegaConf.load(args.config)
+ cfg = OmegaConf.merge(schema, cfg)
+ main(cfg)
\ No newline at end of file
diff --git a/src/utils/image_process.py b/src/utils/image_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..131965d387b394fc98a780e95dff3a978b91799f
--- /dev/null
+++ b/src/utils/image_process.py
@@ -0,0 +1,83 @@
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from rembg import remove
+from segment_anything import SamPredictor, sam_model_registry
+
+def sam_init(sam_checkpoint, device_id=0):
+ model_type = "vit_h"
+
+ device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu"
+
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
+ predictor = SamPredictor(sam)
+ return predictor
+
+def sam_out_nosave(predictor, input_image, *bbox_sliders):
+ bbox = np.array(bbox_sliders)
+ image = np.asarray(input_image)
+
+ predictor.set_image(image)
+
+ masks_bbox, scores_bbox, logits_bbox = predictor.predict(
+ box=bbox, multimask_output=True
+ )
+
+ out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
+ out_image[:, :, :3] = image
+ out_image_bbox = out_image.copy()
+ out_image_bbox[:, :, 3] = (
+ masks_bbox[-1].astype(np.uint8) * 255
+ ) # np.argmax(scores_bbox)
+ torch.cuda.empty_cache()
+ return Image.fromarray(out_image_bbox, mode="RGBA")
+
+# contrast correction, rescale and recenter
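+# Crops the foreground by its alpha mask, recenters it on a square canvas, resizes to `target_res`,
+# and saves the RGBA result to `save_path`; the returned offsets and square side length allow mapping
+# coordinates between the padded canvas and the original image.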
+def image_preprocess(input_image, save_path, target_res, lower_contrast=True, rescale=True):
+ image_arr = np.array(input_image)
+    in_h, in_w = image_arr.shape[:2]  # numpy image arrays are (height, width, channels)
+
+ if lower_contrast:
+ alpha = 0.8 # Contrast control (1.0-3.0)
+ beta = 0 # Brightness control (0-100)
+ # Apply the contrast adjustment
+ image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta)
+ image_arr[image_arr[..., -1] > 200, -1] = 255
+
+ ret, mask = cv2.threshold(
+ np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
+ )
+ x, y, w, h = cv2.boundingRect(mask)
+ max_size = max(w, h)
+ ratio = 0.75
+ if rescale:
+ side_len = int(max_size / ratio)
+ else:
+        side_len = in_h
+ padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
+ center = side_len // 2
+ padded_image[
+ center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w
+ ] = image_arr[y : y + h, x : x + w]
+ rgba = Image.fromarray(padded_image).resize((target_res, target_res), Image.LANCZOS)
+ rgba.save(save_path)
+ return y + h // 2 - center, x + w // 2 - center, side_len
+
+def pred_bbox(image):
+ image_nobg = remove(image.convert("RGBA"), alpha_matting=True)
+ alpha = np.asarray(image_nobg)[:, :, -1]
+ x_nonzero = np.nonzero(alpha.sum(axis=0))
+ y_nonzero = np.nonzero(alpha.sum(axis=1))
+ x_min = int(x_nonzero[0].min())
+ y_min = int(y_nonzero[0].min())
+ x_max = int(x_nonzero[0].max())
+ y_max = int(y_nonzero[0].max())
+ return x_min, y_min, x_max, y_max
+
+def resize_image(input_raw, size):
+ w, h = input_raw.size
+ ratio = size / max(w, h)
+ resized_w = int(w * ratio)
+ resized_h = int(h * ratio)
+ return input_raw.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
\ No newline at end of file
diff --git a/src/utils/interpolate.py b/src/utils/interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..3981a2faa31c8070d2096ef20e02ccdf7ceecf01
--- /dev/null
+++ b/src/utils/interpolate.py
@@ -0,0 +1,297 @@
+import torch
+import torch.nn.functional as F
+
+def get_rigid_transform(A, B):
+ """
+ Estimate the rigid body transformation between two sets of 3D points.
+ A and B are Nx3 matrices where each row is a 3D point.
+ Returns a rotation matrix R and translation vector t.
+ Args:
+ A, B: [batch, N, 3] matrix of 3D points
+ Outputs:
+ R, t: [batch, 3, 3/1]
+ target = R @ source (source shape [3, 1]) + t
+ """
+ assert A.shape == B.shape, "Input matrices must have the same shape"
+ assert A.shape[-1] == 3, "Input matrices must have 3 columns (x, y, z coordinates)"
+
+ # Compute centroids. [..., 1, 3]
+ centroid_A = torch.mean(A, dim=-2, keepdim=True)
+ centroid_B = torch.mean(B, dim=-2, keepdim=True)
+
+ # Center the point sets
+ A_centered = A - centroid_A
+ B_centered = B - centroid_B
+
+ # Compute the cross-covariance matrix. [..., 3, 3]
+ H = A_centered.transpose(-2, -1) @ B_centered
+
+ # Compute the Singular Value Decomposition. Along last two dimensions
+ U, S, Vt = torch.linalg.svd(H)
+
+ # Compute the rotation matrix
+ R = Vt.transpose(-2, -1) @ U.transpose(-2, -1)
+
+ # Ensure a right-handed coordinate system
+ flip_mask = (torch.det(R) < 0) * -2.0 + 1.0
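+    # flip_mask is +1 where det(R) >= 0 and -1 where det(R) < 0; multiplying the last row of Vt by it
+    # turns an improper rotation (reflection) into a proper rotation before R is recomputed below.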
+ # Vt[:, 2, :] *= flip_mask[..., None]
+
+ # [N] => [N, 3]
+ pad_flip_mask = torch.stack(
+ [torch.ones_like(flip_mask), torch.ones_like(flip_mask), flip_mask], dim=-1
+ )
+ Vt = Vt * pad_flip_mask[..., None]
+
+ # Compute the rotation matrix
+ R = Vt.transpose(-2, -1) @ U.transpose(-2, -1)
+
+ # print(R.shape, centroid_A.shape, centroid_B.shape, flip_mask.shape)
+ # Compute the translation
+ t = centroid_B - (R @ centroid_A.transpose(-2, -1)).transpose(-2, -1)
+ t = t.transpose(-2, -1)
+ return R, t
+
+
+def _test_rigid_transform():
+ # Example usage:
+ A = torch.tensor([[1, 2, 3], [4, 5, 6], [9, 8, 10], [10, -5, 1]]) * 1.0
+
+ R_synthesized = torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) * 1.0
+ # init a random rotation matrix:
+
+ B = (R_synthesized @ A.T).T + 2.0 # Just an example offset
+
+ R, t = get_rigid_transform(A[None, ...], B[None, ...])
+ print("Rotation matrix R:")
+ print(R)
+ print("\nTranslation vector t:")
+ print(t)
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ from pytorch3d. Based on trace_method like: https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L205
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+
+ return quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+
+
+def quternion_to_matrix(r):
+ norm = torch.sqrt(
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
+ )
+
+ q = r / norm[:, None]
+
+ R = torch.zeros((q.size(0), 3, 3), device="cuda")
+
+ r = q[:, 0]
+ x = q[:, 1]
+ y = q[:, 2]
+ z = q[:, 3]
+
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
+ R[:, 0, 1] = 2 * (x * y - r * z)
+ R[:, 0, 2] = 2 * (x * z + r * y)
+ R[:, 1, 0] = 2 * (x * y + r * z)
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
+ R[:, 1, 2] = 2 * (y * z - r * x)
+ R[:, 2, 0] = 2 * (x * z - r * y)
+ R[:, 2, 1] = 2 * (y * z + r * x)
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
+ return R
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ from Pytorch3d
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+ From pytorch3d
+ Multiply two quaternions.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions shape (..., 4).
+ """
+ aw, ax, ay, az = torch.unbind(a, -1)
+ bw, bx, by, bz = torch.unbind(b, -1)
+ ow = aw * bw - ax * bx - ay * by - az * bz
+ ox = aw * bx + ax * bw + ay * bz - az * by
+ oy = aw * by - ax * bz + ay * bw + az * bx
+ oz = aw * bz + ax * by - ay * bx + az * bw
+ ret = torch.stack((ow, ox, oy, oz), -1)
+ ret = standardize_quaternion(ret)
+ return ret
+
+
+def _test_matrix_to_quaternion():
+ # init a random batch of quaternion
+ r = torch.randn((10, 4)).cuda()
+
+ norm = torch.sqrt(
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
+ )
+
+ q = r / norm[:, None]
+
+ q = standardize_quaternion(q)
+
+ R = quternion_to_matrix(q)
+
+ I_rec = R @ R.transpose(-2, -1)
+ I_rec_error = torch.abs(I_rec - torch.eye(3, device="cuda")[None, ...]).max()
+
+ q_recovered = matrix_to_quaternion(R)
+ norm_ = torch.linalg.norm(q_recovered, dim=-1)
+ q_recovered = q_recovered / norm_[..., None]
+ q_recovered = standardize_quaternion(q_recovered)
+
+ print(q_recovered.shape, q.shape, R.shape)
+
+ rec = (q - q_recovered).abs().max()
+
+ print("rotation to I error:", I_rec_error, "quant rec error: ", rec)
+
+
+def _test_matrix_to_quaternion_2():
+ R = (
+ torch.tensor(
+ [[[1, 0, 0], [0, -1, 0], [0, 0, -1]], [[1, 0, 0], [0, 0, 1], [0, -1, 0]]]
+ )
+ * 1.0
+ )
+
+ q_rec = matrix_to_quaternion(R.transpose(-2, -1))
+
+ R_rec = quternion_to_matrix(q_rec)
+
+ print(R_rec)
+
+def interpolate_points_w_R(
+ query_points, query_rotation, drive_origin_pts, drive_displacement, top_k_index
+):
+ """
+ Args:
+ query_points: [n, 3]
+ drive_origin_pts: [m, 3]
+ drive_displacement: [m, 3]
+ top_k_index: [n, top_k] < m
+
+ Or directly call: apply_discrete_offset_filds_with_R(self, origin_points, offsets, topk=6):
+ Args:
+ origin_points: (N_r, 3)
+ offsets: (N_r, 3)
+ in rendering
+ """
+
+ # [n, topk, 3]
+ top_k_disp = drive_displacement[top_k_index]
+ source_points = drive_origin_pts[top_k_index]
+
+ R, t = get_rigid_transform(source_points, source_points + top_k_disp)
+
+ avg_offsets = top_k_disp.mean(dim=1)
+
+ ret_points = query_points + avg_offsets
+
+ new_rotation = quaternion_multiply(matrix_to_quaternion(R), query_rotation)
+
+ return ret_points, new_rotation
+
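+def _example_build_topk_index(query_points, drive_origin_pts, k=8):
+    """Illustrative sketch (not part of the pipeline): build the top_k_index expected by
+    interpolate_points / interpolate_points_w_R with a brute-force nearest-neighbour query,
+    mirroring the torch.cdist / torch.topk pattern used in load_utils.gen_tracking_video.
+    query_points: [n, 3], drive_origin_pts: [m, 3] -> top_k_index: [n, k]
+    """
+    cdist = -1.0 * torch.cdist(query_points, drive_origin_pts)  # negate so that larger = closer
+    _, top_k_index = torch.topk(cdist, k, dim=-1)
+    return top_k_index
+
+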
+def interpolate_points(
+ query_points, query_rotation, drive_origin_pts, drive_current_points, top_k_index
+):
+ source_points = drive_origin_pts[top_k_index] # [n, topk, 3]
+ target_points = drive_current_points[top_k_index] # [n, topk, 3]
+ disp = target_points - source_points
+ avg_offsets = disp.mean(dim=1) # [n, 3]
+ ret_points = query_points + avg_offsets # [n, 3]
+ # ret_points = target_points.mean(dim=1) # [n, 3]
+ R, t = get_rigid_transform(source_points, target_points)
+ new_rotation = quaternion_multiply(matrix_to_quaternion(R), query_rotation)
+ return ret_points, new_rotation
\ No newline at end of file
diff --git a/src/utils/load_utils.py b/src/utils/load_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e64517d31d465c6d24bedb8911264485ccf2485
--- /dev/null
+++ b/src/utils/load_utils.py
@@ -0,0 +1,221 @@
+import numpy as np
+import torch
+import gc
+from PIL import Image
+import sys
+import os
+
+# Add the project root directory to Python path (use absolute paths for robustness)
+project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(project_root)
+sys.path.append(os.path.join(project_root, "libs"))
+sys.path.append(os.path.join(project_root, "libs", "LGM"))
+sys.path.append(os.path.join(project_root, "libs", "das"))
+sys.path.append(os.path.join(project_root, "src"))
+
+from sv3d.diffusers_sv3d import SV3DUNetSpatioTemporalConditionModel, StableVideo3DDiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from diffusers import AutoencoderKL, EulerDiscreteScheduler, DDPMScheduler, DDIMScheduler
+from diffusers.utils import export_to_gif, export_to_video
+from kiui.cam import orbit_camera
+from safetensors.torch import load_file
+from omegaconf import OmegaConf
+
+from LGM.core.models import LGM
+from LGM.core.options import AllConfigs
+from LGM.core.gs import GaussianRenderer
+from .track_utils.visualize_tracks import visualize_tracks
+from .track_utils.preprocessing import track_first, find_and_remove_nearest_point
+from .interpolate import interpolate_points
+from das.models.pipelines import DiffusionAsShaderPipeline
+
+import h5py
+import tyro
+from tqdm import tqdm
+from options import TestingConfig
+from pipeline_traj import TrajPipeline
+from model.spacetime import MDM_ST
+from argparse import Namespace
+
+def load_sv3d_pipeline(device, model_path="chenguolin/sv3d-diffusers"):
+ unet = SV3DUNetSpatioTemporalConditionModel.from_pretrained(model_path, subfolder="unet")
+ vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
+ scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_path, subfolder="image_encoder")
+ feature_extractor = CLIPImageProcessor.from_pretrained(model_path, subfolder="feature_extractor")
+ pipeline = StableVideo3DDiffusionPipeline(
+ image_encoder=image_encoder, feature_extractor=feature_extractor,
+ unet=unet, vae=vae,
+ scheduler=scheduler,
+ ).to(device)
+ return pipeline
+
+def load_LGM(opt, device, lgm_ckpt_path="./checkpoints/lgm_fp16.safetensors"):
+ model = LGM(opt)
+ ckpt = load_file(lgm_ckpt_path, device='cpu')
+ model.load_state_dict(ckpt, strict=False)
+ model = model.half().to(device)
+ model.eval()
+ return model
+
+def load_diffusion(device, model_cfg_path, diffusion_ckpt_path, seed=0):
+ schema = OmegaConf.structured(TestingConfig)
+ cfg = OmegaConf.load(model_cfg_path)
+ cfg = OmegaConf.merge(schema, cfg)
+ n_training_frames = cfg.train_dataset.n_training_frames
+ n_frames_interval = cfg.train_dataset.n_frames_interval
+ norm_fac = cfg.train_dataset.norm_fac
+
+ model = MDM_ST(cfg.pc_size, n_training_frames, n_feats=3, model_config=cfg.model_config).to(device)
+
+ ckpt = load_file(diffusion_ckpt_path, device='cpu')
+ model.load_state_dict(ckpt, strict=False)
+ model.eval().requires_grad_(False)
+ noise_scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type='sample', clip_sample=False)
+ pipeline = TrajPipeline(model=model, scheduler=noise_scheduler)
+ return pipeline
+
+def gen_tracking_video(base_dir):
+
+ animated_points = np.load(f'{base_dir}/gen_data.npy')
+ animated_points = animated_points * 2
+ new_animate_points = np.zeros((49, 2048, 3))
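+    # Upsample the generated trajectory (one sample per two video frames) to 49 frames by
+    # interleaving the generated frames with linear interpolations of consecutive frames,
+    # then duplicating the first and last frames.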
+ for i in range(47):
+ if i % 2 == 0:
+ new_animate_points[i + 1] = animated_points[i // 2]
+ else:
+ new_animate_points[i + 1] = (animated_points[i // 2] + animated_points[i // 2 + 1]) / 2
+ new_animate_points[0] = new_animate_points[1]
+ new_animate_points[48] = new_animate_points[47]
+ animated_points = new_animate_points
+
+ projection_matrix = np.load(f'{base_dir}/projection.npy')
+ crop_info = np.load(f'{base_dir}/crop_info.npy')
+ center = np.load(f'{base_dir}/center.npy')
+ scale = np.load(f'{base_dir}/scale.npy')
+ animated_points = (animated_points / scale) + center
+
+ ## Aligned to Gaussian points at this moment
+ print(animated_points.mean(), animated_points.std(), animated_points.max(), animated_points.min())
+ device = torch.device("cuda")
+ sys.argv = ['pipeline_track_gen.py', 'big']
+ opt = tyro.cli(AllConfigs)
+
+ scale_factor = 2
+ focal = 0.5 * opt.output_size / np.tan(np.deg2rad(opt.fovy) / 2)
+ new_fovy_rad = scale_factor * np.arctan(opt.output_size / focal)
+ new_fovy_deg = np.rad2deg(new_fovy_rad)
+ opt.fovy = new_fovy_deg
+ opt.output_size *= scale_factor # Expand canvas size by 2
+
+ gs = GaussianRenderer(opt)
+ gaussians = gs.load_ply(f'{base_dir}/point_cloud.ply', compatible=True).to(device).float()
+ idx = torch.from_numpy(np.load(f'{base_dir}/idx.npy')).to(device)
+ gaussian_pos = gaussians[:, :3].contiguous()
+ drive_x = gaussian_pos[idx]
+ cdist = -1.0 * torch.cdist(gaussian_pos, drive_x) # [N, 2048]
+ _, topk_index = torch.topk(cdist, 8, -1)
+
+ cam_poses = torch.from_numpy(orbit_camera(0, 0, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ gs.proj_matrix.to(device) # [V, 4, 4]
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+ pos = []
+
+ for i in tqdm(range(0, 49, 1)):
+ drive_current = torch.from_numpy(animated_points[i]).to(device).float()
+ ret_points, new_rotation = interpolate_points(gaussian_pos, gaussians[:, 7:11], drive_x, drive_current, topk_index)
+ gaussians_new = gaussians.clone()
+ gaussians_new[:, :3] = ret_points
+ gaussians_new[:, 7:11] = new_rotation
+ pos.append(ret_points.cpu().numpy())
+
+ # with torch.no_grad():
+ # ret = gs.render(gaussians_new.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)
+ # mask = (ret['alpha'][0,0].permute(1, 2, 0).contiguous().float().cpu().numpy() * 255.0).astype(np.uint8)
+ # image = (ret['image'][0, 0].permute(1, 2, 0).contiguous().float().cpu().numpy()*255.0).astype(np.uint8)
+ # image_save = np.concatenate([image, mask], axis=-1)
+
+ # h_begin, w_begin, res = crop_info[0], crop_info[1], crop_info[2]
+ # h_begin = h_begin - (256 * scale_factor - 256)
+ # w_begin = w_begin - (256 * scale_factor - 256)
+ # image_save = Image.fromarray(image_save).resize((res * scale_factor, res * scale_factor), Image.LANCZOS)
+
+ template_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'templates', 'tracks_template.npy')
+ track_template = np.load(template_path, allow_pickle=True)
+ tracks = track_template.item()['tracks']
+ tracks_output = tracks.copy()
+ tracks_init = tracks[0, 0]
+ track_idx = []
+ mask = np.zeros(tracks_init.shape[0], dtype=bool)
+
+ for i in tqdm(range(49)):
+
+ # points = animated_points[i]
+ points = pos[i]
+
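+        # Project the 3D points with the stored 4x4 projection matrix, apply the perspective
+        # divide, map the resulting NDC coordinates to pixel coordinates of the cropped render,
+        # then shift by the crop origin (w_begin, h_begin) into full-frame pixels.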
+ projected_points = (projection_matrix.T @ np.hstack((points, np.ones((points.shape[0], 1)))).T).T
+ projected_points_weights = 1. / (projected_points[:, -1:] + 1e-8)
+ projected_points = (projected_points * projected_points_weights)[:, :-1]
+
+ h_begin, w_begin, res = crop_info[0], crop_info[1], crop_info[2]
+ image_shape = (res, res) # Example image shape (H, W)
+ projected_points[:, :2] = ((projected_points[:, :2] + 1) * image_shape[1] - 1) / 2
+ projected_points[:, 0] += w_begin
+ projected_points[:, 1] += h_begin
+
+ if i == 0:
+ track_point_candidates = track_first(projected_points, (480, 720))
+ for j in range(tracks_init.shape[0]):
+ x, y = tracks_init[j, 0], tracks_init[j, 1]
+ target = np.array([x, y])
+ candidate, track_point_candidates = find_and_remove_nearest_point(target, track_point_candidates)
+ if candidate is not None:
+ track_idx.append(candidate[3].astype(np.int32))
+ mask[j] = True
+
+ tracks_output[0, i, mask] = projected_points[track_idx]
+ tracks_output[0, i, ~mask, :2] = tracks_output[0, 0, ~mask, :2]
+ tracks_output[0, i, ~mask, 2] = 2
+
+ track_template.item()['tracks'] = tracks_output
+ # track_template.item()['drag_points'] = np.stack(drag_points, axis=0)
+ sub_name = 'tracks_gen'
+ sub_dir = f'{base_dir}/{sub_name}'
+ os.makedirs(sub_dir, exist_ok=True)
+
+ np.save(f'{sub_dir}/tracks.npy', track_template)
+ args = Namespace(tracks_dir=sub_dir, output_dir=sub_dir, output_fps=24, point_size=10, len_track=0, num_frames=49, video_path=None)
+ visualize_tracks(tracks_dir=sub_dir, output_dir=sub_dir, args=args)
+
+def load_das(gpu_id, output_dir):
+ das = DiffusionAsShaderPipeline(gpu_id=gpu_id, output_dir=output_dir)
+ return das
+
+def normalize_points(output_dir, fluid=False):
+ from .transform import transform2origin, shift2center
+ import trimesh
+ from torch_cluster import fps
+
+ device = 'cuda'
+
+ pc_path = f'{output_dir}/point_cloud.ply'
+ pc = trimesh.load_mesh(pc_path)
+ points = pc.vertices
+ points = np.array(points)
+ points, center, scale = transform2origin(points, size=1)
+ N = 2048
+ grid_center = [5, 5, 5]
+ drag_size = [0.4, 0.4, 0.4]
+
+ points = shift2center(points, center=grid_center)
+ points = torch.tensor(points, dtype=torch.float32, device=device).contiguous()
+ np.save(f'{output_dir}/center.npy', center)
+ np.save(f'{output_dir}/scale.npy', scale)
+ ratio_N = N / points.shape[0]
+ idx = fps(points, ratio=ratio_N, random_start=True)
+ points = points[idx].cpu().numpy()
+ np.save(f'{output_dir}/idx.npy', idx.cpu().numpy())
+ return points, center, scale
\ No newline at end of file
diff --git a/src/utils/loading.py b/src/utils/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..81c7289bdffb4a143c17bdb73d1bb00509b3a5a4
--- /dev/null
+++ b/src/utils/loading.py
@@ -0,0 +1,39 @@
+
+from PIL import Image
+from typing import Tuple
+
+import trimesh
+import numpy as np
+
+def load_mesh(path):
+ mesh = trimesh.load_mesh(path, force='mesh')
+ if isinstance(mesh, trimesh.Scene):
+ mesh = trimesh.util.concatenate(mesh.dump())
+ return mesh
+
+def paste_image(A: Image.Image, B: Image.Image, h: int, w: int) -> Image.Image:
+ A = A.convert("RGBA")
+ B = B.convert("RGBA") # Ensure B has an alpha channel
+
+ A_width, A_height = A.size
+ B_width, B_height = B.size
+
+ # Crop A if h or w are negative
+ crop_left = max(0, -w)
+ crop_top = max(0, -h)
+ A_cropped = A.crop((crop_left, crop_top, A_width, A_height))
+
+ # Adjust destination position on B
+ paste_x = max(0, w)
+ paste_y = max(0, h)
+
+ # Ensure A_cropped fits within B bounds
+ max_w = B_width - paste_x
+ max_h = B_height - paste_y
+ A_cropped = A_cropped.crop((0, 0, min(A_cropped.width, max_w), min(A_cropped.height, max_h)))
+
+ # Use alpha channel of A as mask
+ alpha = A_cropped.split()[-1]
+ B.paste(A_cropped, (paste_x - 2, paste_y - 2), mask=alpha)
+
+ return B
\ No newline at end of file
diff --git a/src/utils/physics.py b/src/utils/physics.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a04624c3609fbcf3e190be17ae3c76cc379653d
--- /dev/null
+++ b/src/utils/physics.py
@@ -0,0 +1,252 @@
+import os
+import h5py
+import torch
+import torch.nn.functional as Fn
+import numpy as np
+import json
+
+class DeformLoss(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ self.device = "cuda"
+ self.N = 2048
+ self.I33 = torch.eye(3, device=self.device).unsqueeze(0).repeat(self.N, 1, 1)
+ self.dT = 0.0417
+ self.grid_lim = 10
+ self.grid_size = 125
+ self.dx = self.grid_lim / self.grid_size
+ self.inv_dx = 1 / self.dx
+ self.density = 1000
+
+ def forward_sequential(self, x, vol, F, C, frame_interval=2, norm_fac=5, v=None):
+
+ # Denormalize x & Double dt (since we sample every 2 frames) for training
+ if norm_fac > 0:
+ x = x * 2 + norm_fac
+ dT = self.dT * frame_interval
+
+ loss = 0
+
+ for bs in range(x.shape[0]):
+
+ particle_mass = (self.density * vol[bs]).unsqueeze(-1).repeat(1, 3)
+
+ start_t = 1 if frame_interval == 1 else 0
+ end_t = x.shape[1] - 2
+ for t in range(start_t, end_t):
+
+ # Initialize
+ grid_m = torch.zeros((self.grid_size, self.grid_size, self.grid_size), device=self.device)
+ grid_v = torch.zeros((self.grid_size, self.grid_size, self.grid_size, 3), device=self.device)
+
+ particle_x = x[bs, t]
+ if v is not None:
+ particle_v = v[bs, t + 1]
+ else:
+ particle_v = (x[bs, t + 2] - x[bs, t]) / (2 * dT)
+
+ particle_F = F[bs, t].reshape(-1, 3, 3)
+ particle_F_next = F[bs, t + 1].reshape(-1, 3, 3)
+ particle_C = C[bs, t].reshape(-1, 3, 3)
+
+ # P2G
+ grid_pos = particle_x * self.inv_dx
+ base_pos = (grid_pos - 0.5).int()
+ fx = grid_pos - base_pos
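+                # Quadratic B-spline interpolation weights (w) and derivatives (dw) over the
+                # 3x3x3 grid neighbourhood, i.e. the standard MPM transfer kernel.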
+ w = [0.5 * ((1.5 - fx) ** 2), 0.75 - ((fx - 1) ** 2), 0.5 * ((fx - 0.5) ** 2)]
+ w = torch.stack(w, dim=2)
+ dw = [fx - 1.5, -2 * (fx - 1), fx - 0.5]
+ dw = torch.stack(dw, dim=2)
+
+ for i in range(3):
+ for j in range(3):
+ for k in range(3):
+ dpos = torch.tensor([i, j, k], device=self.device).unsqueeze(0).repeat(self.N, 1)
+ dpos = (dpos - fx) * self.dx
+ ix = base_pos[:, 0] + i
+ iy = base_pos[:, 1] + j
+ iz = base_pos[:, 2] + k
+ weight = w[:, 0, i] * w[:, 1, j] * w[:, 2, k]
+ dweight = [dw[:, 0, i] * w[:, 1, j] * w[:, 2, k],
+ w[:, 0, i] * dw[:, 1, j] * w[:, 2, k],
+ w[:, 0, i] * w[:, 1, j] * dw[:, 2, k]]
+ dweight = torch.stack(dweight, dim=1) * self.inv_dx
+
+ v_in_add = weight.unsqueeze(-1) * particle_mass * (particle_v + \
+ (particle_C @ dpos.unsqueeze(-1)).squeeze(-1))
+
+ flat_idx = ix * self.grid_size * self.grid_size + iy * self.grid_size + iz
+ flat_idx = flat_idx.long()
+
+ grid_v = grid_v.view(-1, 3)
+ grid_v = grid_v.scatter_add(0, flat_idx.unsqueeze(-1).repeat(1, 3), v_in_add)
+ grid_v = grid_v.view(self.grid_size, self.grid_size, self.grid_size, 3)
+
+ grid_m = grid_m.view(-1)
+ grid_m = grid_m.scatter_add(0, flat_idx, weight * particle_mass[:, 0])
+ grid_m = grid_m.view(self.grid_size, self.grid_size, self.grid_size)
+
+ # Grid Norm
+ grid_m = torch.where(grid_m > 1e-15, grid_m, torch.ones_like(grid_m))
+ grid_v = grid_v / grid_m.unsqueeze(-1)
+
+ # G2P
+ new_F_pred = torch.zeros_like(particle_F)
+
+ for i in range(3):
+ for j in range(3):
+ for k in range(3):
+ dpos = torch.tensor([i, j, k], device=self.device).unsqueeze(0).repeat(self.N, 1).float() - fx
+ ix = base_pos[:, 0] + i
+ iy = base_pos[:, 1] + j
+ iz = base_pos[:, 2] + k
+
+ weight = w[:, 0, i] * w[:, 1, j] * w[:, 2, k]
+ dweight = [dw[:, 0, i] * w[:, 1, j] * w[:, 2, k],
+ w[:, 0, i] * dw[:, 1, j] * w[:, 2, k],
+ w[:, 0, i] * w[:, 1, j] * dw[:, 2, k]]
+ dweight = torch.stack(dweight, dim=1) * self.inv_dx
+ grid_v_local = grid_v[ix, iy, iz]
+ new_F_pred = new_F_pred + (grid_v_local.unsqueeze(-1) @ dweight.unsqueeze(1))
+
+ F_pred = (self.I33 + new_F_pred * dT) @ particle_F
+ loss = loss + Fn.l1_loss(F_pred, particle_F_next)
+ # loss = loss + Fn.l1_loss(particle_F, particle_F_next)
+
+ return loss / x.shape[0]
+
+ def forward(self, x, vol, F, C, frame_interval=2, norm_fac=5, v=None):
+
+ # Denormalize x & Double dt (since we sample every 2 frames) for training
+ if norm_fac > 0:
+ x = x * 2 + norm_fac
+ dT = self.dT * frame_interval
+
+ loss = 0
+
+ bs = x.shape[0]
+ start_t = 1 if frame_interval == 1 else 0
+ end_t = x.shape[1] - 2
+ M = bs * (end_t - start_t)
+
+ # Initialize
+ grid_m = torch.zeros((M, self.grid_size, self.grid_size, self.grid_size), device=self.device)
+ grid_v = torch.zeros((M, self.grid_size, self.grid_size, self.grid_size, 3), device=self.device)
+
+ particle_x = x[:, start_t:end_t].reshape(M, self.N, 3)
+ # particle_x = x[:, (start_t+1):(end_t+1)].reshape(M, self.N, 3)
+
+ if v is not None:
+ # particle_v = v[:, start_t:end_t].reshape(M, self.N, 3)
+ particle_v = v[:, (start_t+1):(end_t+1)].reshape(M, self.N, 3)
+ else:
+ particle_v = (x[:, (start_t+2):(end_t+2)] - x[:, start_t:end_t]) / (2 * dT)
+ particle_v = particle_v.reshape(M, self.N, 3)
+
+ particle_F = F[:, start_t:end_t].reshape(M, self.N, 3, 3)
+ particle_F_next = F[:, (start_t+1):(end_t+1)].reshape(M, self.N, 3, 3)
+
+ particle_C = C[:, start_t:end_t].reshape(M, self.N, 3, 3)
+ # particle_C = C[:, (start_t+1):(end_t+1)].reshape(M, self.N, 3, 3)
+
+ vol = vol.unsqueeze(1).repeat(1, end_t - start_t, 1).reshape(M, self.N)
+ particle_mass = (self.density * vol).unsqueeze(-1).repeat(1, 1, 3)
+
+ # P2G
+ grid_pos = particle_x * self.inv_dx
+ base_pos = (grid_pos - 0.5).int()
+ fx = grid_pos - base_pos
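+        # Quadratic B-spline weights (w) and derivatives (dw), batched over all (batch, frame) pairs at once.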
+ w = [0.5 * ((1.5 - fx) ** 2), 0.75 - ((fx - 1) ** 2), 0.5 * ((fx - 0.5) ** 2)]
+ w = torch.stack(w, dim=3)
+ dw = [fx - 1.5, -2 * (fx - 1), fx - 0.5]
+ dw = torch.stack(dw, dim=3)
+
+ for i in range(3):
+ for j in range(3):
+ for k in range(3):
+
+ dpos = torch.tensor([i, j, k], device=self.device).unsqueeze(0).unsqueeze(0).repeat(M, self.N, 1)
+ dpos = (dpos - fx) * self.dx
+ ix = base_pos[:, :, 0] + i
+ iy = base_pos[:, :, 1] + j
+ iz = base_pos[:, :, 2] + k
+
+ weight = w[:, :, 0, i] * w[:, :, 1, j] * w[:, :, 2, k]
+ dweight = [dw[:, :, 0, i] * w[:, :, 1, j] * w[:, :, 2, k],
+ w[:, :, 0, i] * dw[:, :, 1, j] * w[:, :, 2, k],
+ w[:, :, 0, i] * w[:, :, 1, j] * dw[:, :, 2, k]]
+ dweight = torch.stack(dweight, dim=2) * self.inv_dx
+
+ v_in_add = weight.unsqueeze(-1) * particle_mass * (particle_v + \
+ (particle_C @ dpos.unsqueeze(-1)).squeeze(-1))
+
+ flat_idx = ix * self.grid_size * self.grid_size + iy * self.grid_size + iz
+ flat_idx = flat_idx.long()
+
+ grid_v = grid_v.view(M, -1, 3)
+ grid_v = grid_v.scatter_add(1, flat_idx.unsqueeze(-1).repeat(1, 1, 3), v_in_add)
+ grid_v = grid_v.view(M, self.grid_size, self.grid_size, self.grid_size, 3)
+
+ grid_m = grid_m.view(M, -1)
+ grid_m = grid_m.scatter_add(1, flat_idx, weight * particle_mass[:, :, 0])
+ grid_m = grid_m.view(M, self.grid_size, self.grid_size, self.grid_size)
+ # Grid Norm
+ grid_m = torch.where(grid_m > 1e-15, grid_m, torch.ones_like(grid_m))
+ grid_v = grid_v / grid_m.unsqueeze(-1)
+
+ # G2P
+ new_F_pred = torch.zeros_like(particle_F)
+
+ for i in range(3):
+ for j in range(3):
+ for k in range(3):
+
+ dpos = torch.tensor([i, j, k], device=self.device).unsqueeze(0).unsqueeze(0).repeat(M, self.N, 1).float() - fx
+ ix = base_pos[:, :, 0] + i
+ iy = base_pos[:, :, 1] + j
+ iz = base_pos[:, :, 2] + k
+ weight = w[:, :, 0, i] * w[:, :, 1, j] * w[:, :, 2, k]
+ dweight = [dw[:, :, 0, i] * w[:, :, 1, j] * w[:, :, 2, k],
+ w[:, :, 0, i] * dw[:, :, 1, j] * w[:, :, 2, k],
+ w[:, :, 0, i] * w[:, :, 1, j] * dw[:, :, 2, k]]
+
+ dweight = torch.stack(dweight, dim=2) * self.inv_dx
+ flat_idx = ix * self.grid_size * self.grid_size + iy * self.grid_size + iz
+ flat_idx = flat_idx.long()
+
+ grid_v = grid_v.view(M, -1, 3)
+ grid_v_local = grid_v.gather(1, flat_idx.unsqueeze(-1).repeat(1, 1, 3))
+ new_F_pred = new_F_pred + (grid_v_local.unsqueeze(-1) @ dweight.unsqueeze(2))
+
+ F_pred = (self.I33 + new_F_pred * dT) @ particle_F
+ loss = loss + Fn.l1_loss(F_pred, particle_F_next)
+ return loss * (end_t - start_t)
+
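+def _example_deform_loss_usage():
+    """Illustrative sketch only (requires CUDA, since DeformLoss hard-codes device='cuda').
+    The tensor shapes are assumptions inferred from DeformLoss.forward: x [B, T, N, 3] normalized
+    particle positions, vol [B, N] per-particle volumes, F and C [B, T, N, 9] flattened 3x3 matrices.
+    """
+    loss_fn = DeformLoss()
+    B, T, N = 1, 4, loss_fn.N
+    x = torch.rand(B, T, N, 3, device=loss_fn.device)          # normalized particle positions
+    vol = torch.full((B, N), 1e-4, device=loss_fn.device)      # per-particle volume
+    F_def = torch.eye(3, device=loss_fn.device).reshape(1, 1, 1, 9).repeat(B, T, N, 1)  # identity deformation gradients
+    C = torch.zeros(B, T, N, 9, device=loss_fn.device)         # affine velocity fields
+    return loss_fn(x, vol, F_def, C)
+
+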
+def loss_momentum(x, vol, force, drag_pt_num, start_frame=1, frame_interval=2,
+ norm_fac=5, v=None, density=1000, dt=0.0417):
+
+ # Denormalize x & Double dt (since we sample every 2 frames) for training
+ if norm_fac > 0:
+ x = x * 2 + norm_fac
+ dt = dt * frame_interval
+
+ loss = []
+ if v is not None:
+ v_curr = v[:, 1:-1]
+ else:
+ v_pos = x[:, 1:-1] - x[:, :-2]
+ v_neg = x[:, 2:] - x[:, 1:-1]
+ v_curr = (v_pos + v_neg) / (2 * dt)
+
+ p_int = density * vol.unsqueeze(-1).unsqueeze(1) * v_curr
+ p_int = p_int.sum(dim=2)
+ dt_acc = torch.arange(1, x.shape[1] - 1, device=p_int.device, dtype=p_int.dtype) * dt
+ force = force.unsqueeze(1)
+ drag_pt_num = drag_pt_num.unsqueeze(1)
+ dt_acc = dt_acc.unsqueeze(0).unsqueeze(-1).repeat(drag_pt_num.shape[0], 1, 3)
+ p_ext = force * dt_acc * drag_pt_num
+ p_ext = p_ext + start_frame * force * (dt / frame_interval) * drag_pt_num
+ loss = Fn.mse_loss(p_int, p_ext)
+ return loss
\ No newline at end of file
diff --git a/src/utils/physparam.py b/src/utils/physparam.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ea15fa3421a67127a03eade607baeb8b1c690d2
--- /dev/null
+++ b/src/utils/physparam.py
@@ -0,0 +1,222 @@
+import torch
+from diffusers import DDPMScheduler, DDIMScheduler
+from dataset.traj_dataset import TrajDataset
+from model.mdm import MDM
+from model.mdm_dit import MDM_DiT
+from model.spacetime import MDM_ST
+import sys
+from options import TrainingConfig, TestingConfig
+from omegaconf import OmegaConf
+from pipeline_traj import TrajPipeline
+from safetensors.torch import load_file
+import argparse
+import os
+import torch.nn as nn
+import torch.nn.functional as F
+from eval import create_model
+from tqdm import tqdm
+import numpy as np
+from utils.visualization import save_pointcloud_video, save_pointcloud_json, save_threejs_html
+import matplotlib.pyplot as plt
+
+def fibonacci_sphere(n):
+ i = torch.arange(n, dtype=torch.float32)
+    phi = 2 * torch.pi * i / ((1 + 5**0.5) / 2)  # golden-angle spacing
+ z = 1 - 2 * (i + 0.5) / n # uniform in [-1,1]
+ r_xy = (1 - z**2).sqrt()
+ x = r_xy * torch.cos(phi)
+ y = r_xy * torch.sin(phi)
+ return torch.stack((x, y, z), dim=1) # shape (n,3)
+
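+# Example (illustrative): fibonacci_sphere(256) gives 256 approximately uniform unit directions;
+# every row has unit norm since x^2 + y^2 + z^2 = (1 - z^2) + z^2 = 1.
+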
+class Inferrer:
+ def __init__(self, args, device='cuda'):
+ self.args = args
+ self.device = device
+ self.model = create_model(args).to(device)
+
+ ckpt = load_file(args.resume, device='cpu')
+ self.model.load_state_dict(ckpt, strict=False)
+ self.model.eval().requires_grad_(False).to(device)
+ self.scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type='sample', clip_sample=False)
+ self.pipeline = TrajPipeline(model=self.model, scheduler=self.scheduler)
+
+ @torch.no_grad()
+ def probe_params(self, init_pc, force, motion_obs, mask, drag_point, floor_height, coeff, y, vis_dir=None, fname=None):
+ out = []
+ for e in torch.arange(4.0, 7.1, 0.5):
+ # for n in torch.arange(0.2, 0.45, 0.05):
+            n = 0.36  # Poisson's ratio held fixed while probing E (assumed default value)
+            E, nu = torch.tensor([e], device=self.device).reshape(1, 1), torch.tensor([n], device=self.device).reshape(1, 1)
+ motion_pred = self.pipeline(init_pc, force, E, nu, mask, drag_point, floor_height, gravity=None, coeff=coeff, y=y, device=self.device, batch_size=1, generator=torch.Generator().manual_seed(self.args.seed), n_frames=self.args.train_dataset.n_training_frames, num_inference_steps=25)
+ loss = F.mse_loss(motion_pred, motion_obs.to(self.device))
+            out.append([loss.item(), e.item(), n])
+ # save_pointcloud_video(motion_pred.squeeze().cpu().numpy(), motion_obs.squeeze().cpu().numpy(), os.path.join(f'{e.item():03f}_{nu.item():02f}.gif'), drag_mask=mask[:1, 0, :, 0].cpu().numpy().squeeze(), vis_flag='objaverse')
+ out = torch.tensor(out).cpu().numpy()
+ print("Best E, nu: ", out[np.argmin(out[:, 0])])
+ plt.plot(out[:, 1], out[:, 0], marker='o', linestyle='-', linewidth=2)
+ plt.xlabel('E')
+ plt.ylabel('Loss')
+ plt.savefig(os.path.join(vis_dir, f'{fname}.png'))
+ plt.close()
+
+ return out
+
+ def forward_model(self, motion_noisy, t, init_pc, force, E, nu, mask, guidance_scale=1.0):
+ bsz = motion_noisy.shape[0]
+ null_emb = torch.tensor([1] * motion_noisy.shape[0]).to(motion_noisy.dtype)
+        if guidance_scale > 1.0:
+ motion_noisy = torch.cat([motion_noisy] * 2)
+ init_pc = torch.cat([init_pc] * 2)
+ force = torch.cat([force] * 2)
+ E = torch.cat([E] * 2)
+ nu = torch.cat([nu] * 2)
+ t = torch.cat([t] * 2)
+ mask = torch.cat([mask] * 2)
+ null_emb = torch.cat([torch.tensor([0] * bsz).to(motion_noisy.dtype), null_emb])
+ null_emb = null_emb[:, None, None].to(self.device, dtype=motion_noisy.dtype)
+ model_output = self.model(motion_noisy, t, init_pc, force, E, nu, mask)
+        if guidance_scale > 1.0:
+ model_pred_uncond, model_pred_cond = model_output.chunk(2)
+ model_output = model_pred_uncond + guidance_scale * (model_pred_cond - model_pred_uncond)
+ return model_output
+
+ def inference_model(self, init_pc, force, E, nu, mask, drag_point, floor_height, coeff,
+ generator,
+ device,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ guidance_scale=1.0,
+                  n_frames=20,
+                  y=None
+ ):
+ # Sample gaussian noise to begin loop
+ sample = torch.randn((batch_size, n_frames, init_pc.shape[2], 3), generator=generator).to(device)
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+ do_classifier_free_guidance = (guidance_scale > 1.0)
+ null_emb = torch.tensor([1] * batch_size).to(sample.dtype)
+ if do_classifier_free_guidance:
+ init_pc = torch.cat([init_pc] * 2)
+ force = torch.cat([force] * 2)
+ E = torch.cat([E] * 2)
+ nu = torch.cat([nu] * 2)
+ mask = torch.cat([mask] * 2)
+ drag_point = torch.cat([drag_point] * 2)
+ floor_height = torch.cat([floor_height] * 2)
+ null_emb = torch.cat([torch.tensor([0] * batch_size).to(sample.dtype), null_emb])
+ null_emb = null_emb[:, None, None].to(device)
+ for t in self.scheduler.timesteps:
+ t = torch.tensor([t] * batch_size, device=device)
+ sample_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+ t = torch.cat([t] * 2) if do_classifier_free_guidance else t
+ # 1. predict noise model_output
+ model_output = self.model(sample_input, t, init_pc, force, E, nu, mask, drag_point, floor_height, coeff, y=y, null_emb=null_emb)
+ if do_classifier_free_guidance:
+ model_pred_uncond, model_pred_cond = model_output.chunk(2)
+ model_output = model_pred_uncond + guidance_scale * (model_pred_cond - model_pred_uncond)
+ sample = self.scheduler.step(model_output, t[0], sample).prev_sample
+ return sample
+
+ def estimate_params(self, model_name, motion_obs, init_pc, force, mask, drag_point, floor_height, coeff, y, cfg=1.0, gravity=None, probe=False, num_steps=400):
+        device = 'cuda'
+        # nu is not optimized here; hold it at the fixed default from the commented-out parameter below (assumed value)
+        nu = torch.tensor([0.15], device=device).reshape(1, 1)
+
+ all_loss = []
+ if probe:
+ out = []
+ for e in torch.arange(4.0, 7.1, 0.5):
+ E = torch.tensor([e], device=self.device).reshape(1, 1)
+ motion_pred = self.pipeline(init_pc, force, E, nu, mask, drag_point, floor_height, gravity=gravity, coeff=coeff, y=y, device=self.device, batch_size=1, generator=torch.Generator().manual_seed(self.args.seed), n_frames=self.args.train_dataset.n_training_frames, num_inference_steps=25)
+ loss = F.mse_loss(motion_pred, motion_obs.to(self.device))
+ out.append([loss.item(), E.item()])
+ out = torch.tensor(out)
+ print("Best E, nu: ", out[torch.argmin(out[:, 0])])
+ E = nn.Parameter(torch.tensor([out[np.argmin(out[:, 0]), 1]], device=device).reshape(1, 1))
+ all_loss.append(out[torch.argmin(out[:, 0])])
+ else:
+ E = nn.Parameter(torch.tensor([4.5], device=device).reshape(1, 1))
+ # nu = nn.Parameter(torch.tensor([0.15], device=device).reshape(1, 1))
+ # force = nn.Parameter(torch.zeros([1, 3], device=device))
+ # drag_point = nn.Parameter(torch.zeros([1, 3], device=device))
+ optimizer = torch.optim.Adam([
+ {'params': E, 'lr': 1e-2, 'min': 4.0, 'max': 7.0},
+ # {'params': nu, 'lr': 1e-2, 'min': 0.15, 'max': 0.4},
+ # {'params': [force, drag_point], 'lr': 1e-2}
+ ])
+ self.model.requires_grad_(True)
+ progress_bar = tqdm(total=num_steps)
+ progress_bar.set_description("Training")
+ Es = []
+ for step in range(num_steps):
+ optimizer.zero_grad()
+ noise = torch.randn_like(motion_obs, device=device)
+ t = torch.randint(0, self.scheduler.num_train_timesteps, (motion_obs.shape[0],), device=device)
+ motion_noisy = self.scheduler.add_noise(motion_obs, noise, t)
+ model_output = self.model(motion_noisy, t, init_pc, force, E, nu, mask, drag_point, floor_height, gravity, coeff, y=y)
+ loss = F.mse_loss(model_output, motion_obs)
+ progress_bar.update(1)
+ progress_bar.set_postfix({'loss': loss.item(), 'E': E.item(), 'nu': nu.item()})
+ loss.backward()
+ optimizer.step()
+
+ with torch.no_grad():
+ E.clamp_(4.0, 7.0)
+
+ if (step + 1) % 200 == 0:
+ Es.append(E.item())
+
+ if (step + 1) % 200 == 0:
+ motion_pred = self.pipeline(init_pc, force.detach(), E.detach(), nu.detach(), mask, drag_point.detach(), floor_height, gravity=gravity, coeff=coeff, y=y, device=self.device, batch_size=1, generator=torch.Generator().manual_seed(self.args.seed), n_frames=self.args.train_dataset.n_training_frames, num_inference_steps=25)
+ loss = F.mse_loss(motion_pred, motion_obs)
+ all_loss.append(torch.tensor([loss.item(), E.item()]))
+ out = torch.stack(all_loss)
+ print(out)
+ E = out[torch.argmin(out[:, 0]), 1].to(device)
+ Es.append(E.item())
+ save_pointcloud_video(motion_pred.squeeze().cpu().numpy(), motion_obs.squeeze().cpu().numpy(), os.path.join(f'./debug/v3', f'{model_name}_{E.item():03f}_{nu.item():02f}.gif'), drag_mask=mask[:1, 0, :, 0].cpu().numpy().squeeze(), vis_flag='objaverse')
+ return Es, nu, force, drag_point
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True)
+ args = parser.parse_args()
+ schema = OmegaConf.structured(TestingConfig)
+ cfg = OmegaConf.load(args.config)
+ args = OmegaConf.merge(schema, cfg)
+
+ val_dataset = TrajDataset('val', args.train_dataset)
+ # val_dataset = [val_dataset[i] for i in range(len(val_dataset) - 15, len(val_dataset))]
+ val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.dataloader_num_workers)
+
+ inferrer = Inferrer(args)
+ loss = 0.0
+ loss_mask = 0.0
+ probe = True
+ num_steps = 400
+ for i, (batch, _) in enumerate(val_dataloader):
+ device = torch.device('cuda')
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ model_name = batch['model'][0]
+ motion_obs = batch['points_tgt'].to(device)
+ init_pc = batch['points_src'].to(device)
+ force = batch['force'].to(device)
+ E = batch['E'].to(device)
+ nu = batch['nu'].to(device)
+ mask = batch['mask'][..., :1].to(device, dtype=force.dtype)
+ drag_point = batch['drag_point'].to(device)
+ floor_height = batch['floor_height'].to(device)
+ coeff = batch['base_drag_coeff']
+ y=None if 'mat_type' not in batch else batch['mat_type'].to(device)
+ gravity = batch['gravity'].to(device) if 'gravity' in batch else None
+ print(model_name, floor_height)
+ # for j in range(output.shape[0]):
+ # save_pointcloud_video(motion_obs.squeeze().cpu().numpy(), motion_obs.squeeze().cpu().numpy(), os.path.join('./debug', f'{i:03d}_{E.item():03f}_{nu.item():02f}.gif'), drag_mask=mask[:1, 0, :, 0].cpu().numpy().squeeze(), vis_flag='objaverse')
+
+ print('GT', E, nu, drag_point, force, y)
+
+ est_E, est_nu, est_f, est_d = inferrer.estimate_params(model_name, motion_obs.to(device), init_pc, force, mask, drag_point, floor_height, coeff, y=y, cfg=1.0, gravity=gravity, probe=probe, num_steps=num_steps)
+ # print(f'EST_{model_name}', F.mse_loss(est_E, E), F.mse_loss(est_nu, nu), F.mse_loss(est_d[..., :3], drag_point[..., :3]), F.mse_loss(est_f, force))
+ est_E = ','.join([f'{e:.3f}' for e in est_E]) if isinstance(est_E, list) else est_E.item()
+ print(est_E)
+ with open(os.path.join('./debug', f'output_probe{probe}_steps{num_steps}.txt'), 'a+') as f:
+ f.write(f'{model_name},{E.item()},{est_E},{nu.item()},{est_nu.item()},{drag_point.cpu().numpy()},{est_d.cpu().numpy()},{force.cpu().numpy()},{est_f.cpu().numpy()}\n')
+ # break
\ No newline at end of file
diff --git a/src/utils/seeding.py b/src/utils/seeding.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b5faa8ebb931d40c0421a0ad81250478958d66b
--- /dev/null
+++ b/src/utils/seeding.py
@@ -0,0 +1,16 @@
+import random
+import numpy as np
+import torch
+import hashlib
+
+def seed_everything(seed):
+ """Set random seeds for Python, NumPy, and PyTorch based on a string."""
+ # Convert string to an integer hash
+ if isinstance(seed, str):
+ seed = int(hashlib.md5(seed.encode()).hexdigest(), 16) % (2**32) # 32-bit seed
+
+ # Set seeds
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
\ No newline at end of file
diff --git a/src/utils/sim_utils.py b/src/utils/sim_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2e4a2f9cfca960695ec4243c633e37aa2783ba
--- /dev/null
+++ b/src/utils/sim_utils.py
@@ -0,0 +1,39 @@
+import taichi as ti
+import torch
+
+@ti.kernel
+def assign_particle_to_grid(pos: ti.template(), grid: ti.template(), grid_dx: float):
+ for pi in range(pos.shape[0]):
+ p = pos[pi]
+ i = ti.floor(p[0] / grid_dx, dtype=int)
+ j = ti.floor(p[1] / grid_dx, dtype=int)
+ k = ti.floor(p[2] / grid_dx, dtype=int)
+ ti.atomic_add(grid[i, j, k], 1)
+
+@ti.kernel
+def compute_particle_volume(
+ pos: ti.template(), grid: ti.template(), particle_vol: ti.template(), grid_dx: float
+):
+ for pi in range(pos.shape[0]):
+ p = pos[pi]
+ i = ti.floor(p[0] / grid_dx, dtype=int)
+ j = ti.floor(p[1] / grid_dx, dtype=int)
+ k = ti.floor(p[2] / grid_dx, dtype=int)
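+        # Split each occupied cell's volume evenly among the particles assigned to it.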
+ particle_vol[pi] = (grid_dx * grid_dx * grid_dx) / grid[i, j, k]
+
+def get_particle_volume(pos, grid_n: int, grid_dx: float, uniform: bool = False):
+ ti_pos = ti.Vector.field(n=3, dtype=float, shape=pos.shape[0])
+ ti_pos.from_torch(pos.reshape(-1, 3))
+
+ grid = ti.field(dtype=int, shape=(grid_n, grid_n, grid_n))
+ particle_vol = ti.field(dtype=float, shape=pos.shape[0])
+
+ assign_particle_to_grid(ti_pos, grid, grid_dx)
+ compute_particle_volume(ti_pos, grid, particle_vol, grid_dx)
+
+ if uniform:
+ vol = particle_vol.to_torch()
+ vol = torch.mean(vol).repeat(pos.shape[0])
+ return vol
+ else:
+ return particle_vol.to_torch()
\ No newline at end of file
diff --git a/src/utils/track_utils/preprocessing.py b/src/utils/track_utils/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5b58cdee376b5c7ce836b48ed0375ad155119ce
--- /dev/null
+++ b/src/utils/track_utils/preprocessing.py
@@ -0,0 +1,81 @@
+import numpy as np
+import torch
+
+from PIL import Image
+
+def project_to_image(points_camera, intrinsic_matrix):
+ """
+ Project 3D points in camera coordinates to 2D image plane.
+ :param points_camera: Nx3 array of 3D points in camera coordinates.
+ :param intrinsic_matrix: 3x3 camera intrinsic matrix.
+ :return: Nx2 array of 2D pixel coordinates.
+ """
+ # Get homogeneous image coordinates
+ points_image_h = intrinsic_matrix @ points_camera.T # 3xN
+ # Normalize to get 2D pixel coordinates
+ points_image = points_image_h[:2, :] / points_image_h[2, :]
+ return points_image.T # Nx2
+
+def get_pad(image, target_width=720):
+
+ # Get the current size
+ if image.ndim == 2: # Grayscale image
+ _, width = image.shape
+ channels = None
+ elif image.ndim == 3: # RGB or RGBA image
+ _, width, channels = image.shape
+ else:
+ raise ValueError("Input image must be 2D or 3D (grayscale, RGB, or RGBA).")
+
+ # Calculate padding
+ padding_left = (target_width - width) // 2
+ padding_right = target_width - width - padding_left
+
+ # Apply padding
+ if channels: # RGB or RGBA image
+ padded_image = np.pad(
+ image,
+ pad_width=((0, 0), (padding_left, padding_right), (0, 0)),
+ mode='constant',
+ constant_values=2
+ )
+ else: # Grayscale image
+ padded_image = np.pad(
+ image,
+ pad_width=((0, 0), (padding_left, padding_right)),
+ mode='constant',
+ constant_values=2
+ )
+
+ return padded_image
+
+def find_and_remove_nearest_point(target, candidates, dense=False):
+
+ offset_x = 5.0543
+ offset_y = 3.3152
+
+ x_min, x_max = target[0] - offset_x, target[0] + offset_x
+ y_min, y_max = target[1] - offset_y, target[1] + offset_y
+ satisfied_idx = np.where((candidates[:, 0] >= x_min) & (candidates[:, 0] <= x_max) & (candidates[:, 1] >= y_min) & (candidates[:, 1] <= y_max))[0]
+ if satisfied_idx.shape[0] == 0:
+ return None, candidates
+
+ satisfied_candidates = candidates[satisfied_idx]
+ distance = np.linalg.norm(satisfied_candidates[:, :2] - target, axis=1)
+ min_idx = np.argmin(distance)
+ candidate = satisfied_candidates[min_idx]
+ kept_idx = np.where(candidates[:, -1] != candidate[-1])
+ updated_candidates = candidates[kept_idx]
+ return candidate, updated_candidates
+
+def track_first(projected_points, image_shape):
+
+ candidate_list = []
+ # Fill image with XYZ values
+ for i, (x, y, z) in enumerate(projected_points):
+ candidate_list.append(np.array([x, y, z, i]))
+ candidate_list = np.stack(candidate_list)
+ return candidate_list
\ No newline at end of file
diff --git a/src/utils/track_utils/visualize_tracks.py b/src/utils/track_utils/visualize_tracks.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a8c0a741fc82697ad5ad940b93cb2817441f91e
--- /dev/null
+++ b/src/utils/track_utils/visualize_tracks.py
@@ -0,0 +1,46 @@
+import os
+import argparse
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .visualizer import Visualizer
+
+def create_white_video(num_frames, target_h=480, target_w=720):
+ white_video = torch.ones((1, num_frames, 3, target_h, target_w))
+ return white_video
+
+def process_video(tracks_path, output_dir, args):
+
+ video_name = os.path.splitext(os.path.basename(tracks_path))[0].replace('_tracks', '')
+ video = create_white_video(args.num_frames)
+
+ combined_data = np.load(tracks_path, allow_pickle=True).item()
+ tracks = torch.from_numpy(combined_data['tracks'])
+ visibility = torch.from_numpy(combined_data['visibility'])
+
+ vis = Visualizer(
+ save_dir=output_dir,
+ grayscale=False,
+ fps=args.output_fps,
+ pad_value=0,
+ linewidth=args.point_size,
+ tracks_leave_trace=args.len_track
+ )
+
+ video_vis = vis.visualize(
+ video=video,
+ tracks=tracks,
+ visibility=visibility,
+ filename=video_name
+ )
+
+def visualize_tracks(tracks_dir, output_dir, args):
+
+ args.tracks_dir = tracks_dir
+
+ os.makedirs(output_dir, exist_ok=True)
+ tracks_files = [f for f in os.listdir(args.tracks_dir) if f.endswith('tracks.npy')]
+ for tracks_file in tracks_files:
+ tracks_path = os.path.join(args.tracks_dir, tracks_file)
+ process_video(tracks_path, output_dir, args)
\ No newline at end of file
diff --git a/src/utils/track_utils/visualizer.py b/src/utils/track_utils/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf0c4ae6dfe5d52212e9a2bc4d20f855faa204e7
--- /dev/null
+++ b/src/utils/track_utils/visualizer.py
@@ -0,0 +1,406 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import cv2
+import torch
+import flow_vis
+
+from matplotlib import cm
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import pkg_resources
+try:
+ moviepy_version = pkg_resources.get_distribution("moviepy").version
+    if pkg_resources.parse_version(moviepy_version) >= pkg_resources.parse_version("2.0.0"):
+ from moviepy import ImageSequenceClip
+ else:
+ from moviepy.editor import ImageSequenceClip
+except ImportError:
+ raise ImportError("moviepy is required for video processing. Please install it with: pip install moviepy==1.0.3")
+import matplotlib.pyplot as plt
+
+
+def read_video_from_path(path):
+ cap = cv2.VideoCapture(path)
+ if not cap.isOpened():
+ print("Error opening video file")
+ else:
+ frames = []
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret == True:
+ frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+ else:
+ break
+ cap.release()
+ return np.stack(frames)
+
+
+class Visualizer:
+ def __init__(
+ self,
+ save_dir: str = "./results",
+ grayscale: bool = False,
+ pad_value: int = 0,
+ fps: int = 10,
+ mode: str = "rainbow", # 'cool', 'optical_flow'
+ linewidth: int = 1,
+ show_first_frame: int = 10,
+ tracks_leave_trace: int = 0, # -1 for infinite
+ ):
+ self.mode = mode
+ self.save_dir = save_dir
+ self.vtxt_path = os.path.join(save_dir, "videos.txt")
+ self.ttxt_path = os.path.join(save_dir, "trackings.txt")
+ if mode == "rainbow":
+ self.color_map = cm.get_cmap("gist_rainbow")
+ elif mode == "cool":
+ self.color_map = cm.get_cmap(mode)
+ self.show_first_frame = show_first_frame
+ self.grayscale = grayscale
+ self.tracks_leave_trace = tracks_leave_trace
+ self.pad_value = pad_value
+ self.linewidth = linewidth
+ self.fps = fps
+
+ def visualize(
+ self,
+ video: torch.Tensor, # (B,T,C,H,W)
+        tracks: torch.Tensor, # (B,T,N,3): x, y, depth
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
+ filename: str = "video",
+ writer=None, # tensorboard Summary Writer, used for visualization during training
+ step: int = 0,
+ query_frame: int = 0,
+ save_video: bool = True,
+ compensate_for_camera_motion: bool = False,
+ rigid_part = None,
+ video_depth = None # (B,T,C,H,W)
+ ):
+ if compensate_for_camera_motion:
+ assert segm_mask is not None
+ if segm_mask is not None:
+ coords = tracks[0, query_frame].round().long()
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
+
+ video = F.pad(
+ video,
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
+ "constant",
+ 255,
+ )
+
+ if video_depth is not None:
+ video_depth = (video_depth*255).cpu().numpy().astype(np.uint8)
+ video_depth = ([cv2.applyColorMap(video_depth[0,i,0], cv2.COLORMAP_INFERNO)
+ for i in range(video_depth.shape[1])])
+ video_depth = np.stack(video_depth, axis=0)
+ video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None]
+
+ tracks = tracks + self.pad_value
+
+ if self.grayscale:
+ transform = transforms.Grayscale()
+ video = transform(video)
+ video = video.repeat(1, 1, 3, 1, 1)
+
+ tracking_video = self.draw_tracks_on_video(
+ video=video,
+ tracks=tracks,
+ visibility=visibility,
+ segm_mask=segm_mask,
+ gt_tracks=gt_tracks,
+ query_frame=query_frame,
+ compensate_for_camera_motion=compensate_for_camera_motion,
+ rigid_part=rigid_part
+ )
+
+ if save_video:
+ tracking_dir = os.path.join(self.save_dir, "tracking")
+ if not os.path.exists(tracking_dir):
+ os.makedirs(tracking_dir)
+ self.save_video(tracking_video, filename=filename+"_tracking",
+ savedir=tracking_dir, writer=writer, step=step)
+ return tracking_video
+
+ def save_video(self, video, filename, savedir=None, writer=None, step=0):
+ if writer is not None:
+ writer.add_video(
+ f"{filename}",
+ video.to(torch.uint8),
+ global_step=step,
+ fps=self.fps,
+ )
+ else:
+ os.makedirs(self.save_dir, exist_ok=True)
+ wide_list = list(video.unbind(1))
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
+ clip = ImageSequenceClip(wide_list, fps=self.fps)
+
+ # Write the video file
+ if savedir is None:
+ save_path = os.path.join(self.save_dir, f"{filename}.mp4")
+ else:
+ save_path = os.path.join(savedir, f"{filename}.mp4")
+ clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
+
+ print(f"Video saved to {save_path}")
+
+ def draw_tracks_on_video(
+ self,
+ video: torch.Tensor,
+ tracks: torch.Tensor,
+ visibility: torch.Tensor = None,
+ segm_mask: torch.Tensor = None,
+ gt_tracks=None,
+ query_frame: int = 0,
+ compensate_for_camera_motion=False,
+ rigid_part=None,
+ ):
+ B, T, C, H, W = video.shape
+ _, _, N, D = tracks.shape
+
+ assert D == 3
+ assert C == 3
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
+ tracks = tracks[0].detach().cpu().numpy() # S, N, 2
+ if gt_tracks is not None:
+ gt_tracks = gt_tracks[0].detach().cpu().numpy()
+
+ res_video = []
+
+ # create a blank tensor with the same shape as the video
+ for rgb in video:
+ black_frame = np.zeros_like(rgb.copy(), dtype=rgb.dtype)
+ res_video.append(black_frame)
+
+ vector_colors = np.zeros((T, N, 3))
+
+ if self.mode == "optical_flow":
+
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
+
+ elif segm_mask is None:
+ if self.mode == "rainbow":
+ x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
+ y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
+
+ z_inv = 1/tracks[0, :, 2]
+ # z_min, z_max = np.percentile(z_inv, [2, 98])
+ z_min, z_max = np.percentile(z_inv, [0.5, 99.9])
+ # print(tracks[0, :, 2].min(), tracks[0, :, 2].max())
+ # print(z_inv.min(), z_inv.max())
+ # print(z_min, z_max)
+
+ norm_x = plt.Normalize(x_min, x_max)
+ norm_y = plt.Normalize(y_min, y_max)
+ norm_z = plt.Normalize(z_min, z_max)
+
+ for n in range(N):
+ r = norm_x(tracks[0, n, 0])
+ g = norm_y(tracks[0, n, 1])
+ # r = 0
+ # g = 0
+ b = norm_z(1/tracks[0, n, 2])
+ color = np.array([r, g, b])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+ else:
+ # color changes with time
+ for t in range(T):
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
+ vector_colors[t] = np.repeat(color, N, axis=0)
+ else:
+ if self.mode == "rainbow":
+ vector_colors[:, segm_mask <= 0, :] = 255
+
+ x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
+ y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
+ z_min, z_max = tracks[0, :, 2].min(), tracks[0, :, 2].max()
+
+ norm_x = plt.Normalize(x_min, x_max)
+ norm_y = plt.Normalize(y_min, y_max)
+ norm_z = plt.Normalize(z_min, z_max)
+
+ for n in range(N):
+ r = norm_x(tracks[0, n, 0])
+ g = norm_y(tracks[0, n, 1])
+ b = norm_z(tracks[0, n, 2])
+ color = np.array([r, g, b])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+
+ else:
+ # color changes with segm class
+ segm_mask = segm_mask.cpu()
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
+ vector_colors = np.repeat(color[None], T, axis=0)
+
+ # Draw tracks
+ # print("Start drawing tracks")
+
+ if self.tracks_leave_trace != 0:
+ for t in range(1, T):
+ first_ind = (
+ max(0, t - self.tracks_leave_trace)
+ if self.tracks_leave_trace >= 0
+ else 0
+ )
+ curr_tracks = tracks[first_ind : t + 1]
+ curr_colors = vector_colors[first_ind : t + 1]
+ if compensate_for_camera_motion:
+ diff = (
+ tracks[first_ind : t + 1, segm_mask <= 0]
+ - tracks[t : t + 1, segm_mask <= 0]
+ ).mean(1)[:, None]
+
+ curr_tracks = curr_tracks - diff
+ curr_tracks = curr_tracks[:, segm_mask > 0]
+ curr_colors = curr_colors[:, segm_mask > 0]
+
+ res_video[t] = self._draw_pred_tracks(
+ res_video[t],
+ curr_tracks,
+ curr_colors,
+ )
+ if gt_tracks is not None:
+ res_video[t] = self._draw_gt_tracks(
+ res_video[t], gt_tracks[first_ind : t + 1]
+ )
+
+ if rigid_part is not None:
+ cls_label = torch.unique(rigid_part)
+ cls_num = len(torch.unique(rigid_part))
+ # visualize the clustering results
+ cmap = plt.get_cmap('jet') # get the color mapping
+ colors = cmap(np.linspace(0, 1, cls_num))
+ colors = (colors[:, :3] * 255)
+ color_map = {lable.item(): color for lable, color in zip(cls_label, colors)}
+
+ # Draw points
+ for t in range(T):
+
+ # Create a list to store information for each point
+ # print(f"Drawing frame {t}")
+ points_info = []
+ for i in range(N):
+                coord = (int(tracks[t, i, 0]), int(tracks[t, i, 1]))
+ depth = tracks[t, i, 2] # assume the third dimension is depth
+ visibile = True
+ if visibility is not None:
+ visibile = visibility[0, t, i]
+ if coord[0] != 0 and coord[1] != 0:
+ if not compensate_for_camera_motion or (
+ compensate_for_camera_motion and segm_mask[i] > 0
+ ):
+ points_info.append((i, coord, depth, visibile))
+
+ # Sort points by depth, points with smaller depth (closer) will be drawn later
+ points_info.sort(key=lambda x: x[2], reverse=True)
+
+ for i, coord, _, visibile in points_info:
+ if rigid_part is not None:
+ color = color_map[rigid_part.squeeze()[i].item()]
+ cv2.circle(
+ res_video[t],
+ coord,
+ int(self.linewidth * 2),
+ color.tolist(),
+                        thickness=-1 if visibile else 1,
+ )
+ else:
+ # Determine rectangle width based on the distance between adjacent tracks in the first frame
+ if t == 0:
+ distances = np.linalg.norm(tracks[0] - tracks[0, i], axis=1)
+ distances = distances[distances > 0]
+ rect_size = int(np.min(distances))/2
+
+ # Define coordinates for top-left and bottom-right corners of the rectangle
+ top_left = (int(coord[0] - rect_size), int(coord[1] - rect_size/1.5)) # Rectangle width is 1.5x (video aspect ratio is 1.5:1)
+ bottom_right = (int(coord[0] + rect_size), int(coord[1] + rect_size/1.5))
+
+ # Draw rectangle
+ cv2.rectangle(
+ res_video[t],
+ top_left,
+ bottom_right,
+ vector_colors[t, i].tolist(),
+                        thickness=-1,
+ )
+
+ # Construct the final rgb sequence
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
+
+ def _draw_pred_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3
+        tracks: np.ndarray,  # T x N x 2
+ vector_colors: np.ndarray,
+ alpha: float = 0.5,
+ ):
+ T, N, _ = tracks.shape
+
+ for s in range(T - 1):
+ vector_color = vector_colors[s]
+ original = rgb.copy()
+ alpha = (s / T) ** 2
+ for i in range(N):
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
+ if coord_y[0] != 0 and coord_y[1] != 0:
+ cv2.line(
+ rgb,
+ coord_y,
+ coord_x,
+ vector_color[i].tolist(),
+ self.linewidth,
+ cv2.LINE_AA,
+ )
+ if self.tracks_leave_trace > 0:
+ rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
+ return rgb
+
+ def _draw_gt_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3,
+        gt_tracks: np.ndarray,  # T x N x 2
+ ):
+ T, N, _ = gt_tracks.shape
+ color = np.array((211.0, 0.0, 0.0))
+
+ for t in range(T):
+ for i in range(N):
+                gt_track = gt_tracks[t][i]
+                # draw a red cross
+                if gt_track[0] > 0 and gt_track[1] > 0:
+                    length = self.linewidth * 3
+                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
+                    cv2.line(
+                        rgb,
+                        coord_y,
+                        coord_x,
+                        color.tolist(),
+                        self.linewidth,
+                        cv2.LINE_AA,
+                    )
+                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
+                    cv2.line(
+                        rgb,
+                        coord_y,
+                        coord_x,
+                        color.tolist(),
+                        self.linewidth,
+                        cv2.LINE_AA,
+                    )
+ return rgb
diff --git a/src/utils/transform.py b/src/utils/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..21902f63863901196094914894c9f3097cbf84d1
--- /dev/null
+++ b/src/utils/transform.py
@@ -0,0 +1,19 @@
+import torch
+import numpy as np
+
+def transform2origin(v, size=1):
+ bmax = v.max(axis=0)
+ bmin = v.min(axis=0)
+ aabb = bmax - bmin
+ center = (bmax + bmin) / 2
+ scale = size / (aabb.max() * 0.5)
+ new_v = (v - center) * scale
+ return new_v, center, scale
+
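+# Example (illustrative): transform2origin centers a point set and scales it so the largest
+# AABB half-extent equals `size`; the original points are recovered as new_v / scale + center, e.g.
+#   v = np.random.rand(100, 3) * 10
+#   new_v, center, scale = transform2origin(v, size=1)
+#   assert np.allclose(new_v / scale + center, v)
+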
+def shift2center_th(position_tensor, center=[5, 5, 5]):
+ tensor = torch.tensor(center, dtype=torch.float32, device=position_tensor.device).contiguous()
+ return position_tensor + tensor
+
+def shift2center(position_tensor, center=[5, 5, 5]):
+ tensor = np.array(center)
+ return position_tensor + tensor
\ No newline at end of file
diff --git a/src/utils/ui_utils.py b/src/utils/ui_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf7b7f81e20b4303fe9bca8080c608d73981652
--- /dev/null
+++ b/src/utils/ui_utils.py
@@ -0,0 +1,82 @@
+import gradio as gr
+from PIL import Image
+import numpy as np
+
+from copy import deepcopy
+import cv2
+import plotly.graph_objects as go
+
+
+
+def mask_image(image,
+ mask,
+ color=[255,0,0],
+ alpha=0.5):
+ """ Overlay mask on image for visualization purpose.
+ Args:
+ image (H, W, 3) or (H, W): input image
+ mask (H, W): mask to be overlaid
+ color: the color of overlaid mask
+ alpha: the transparency of the mask
+ """
+ out = deepcopy(image)
+ img = deepcopy(image)
+ img[mask == 1] = color
+ out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
+ return out
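+
+# Minimal usage sketch (dummy image and mask; shapes are hypothetical):
+def _example_mask_overlay():
+    frame = np.zeros((256, 256, 3), dtype=np.uint8)
+    mask = np.zeros((256, 256), dtype=np.uint8)
+    mask[64:192, 64:192] = 1                      # square region to highlight
+    return mask_image(frame, mask, color=[0, 255, 0], alpha=0.4)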
+
+def image_preprocess(input_image, target_res, lower_contrast=True, rescale=True):
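+    """Center an RGBA object image on a square canvas and resize it.
+
+    Args:
+        input_image: PIL RGBA image whose alpha channel masks the object.
+        target_res: output resolution; the result is target_res x target_res.
+        lower_contrast: if True, slightly reduce contrast and set the alpha of
+            mostly opaque pixels to 255 before cropping.
+        rescale: if True, size the canvas so the object spans ~75% of it;
+            otherwise use the input image size as the canvas side length.
+
+    Returns:
+        (dy, dx, side_len, rgba): the object's bounding-box center in the input
+        image minus half the canvas side (row offset, then column offset), the
+        canvas side length, and the resized RGBA crop.
+    """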
+ image_arr = np.array(input_image)
+    in_h, in_w = image_arr.shape[:2]  # numpy arrays are (height, width, channels)
+
+ if lower_contrast:
+ alpha = 0.8 # Contrast control (1.0-3.0)
+ beta = 0 # Brightness control (0-100)
+ # Apply the contrast adjustment
+ image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta)
+ image_arr[image_arr[..., -1] > 200, -1] = 255
+
+ ret, mask = cv2.threshold(
+ np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
+ )
+ x, y, w, h = cv2.boundingRect(mask)
+ max_size = max(w, h)
+ ratio = 0.75
+ if rescale:
+ side_len = int(max_size / ratio)
+ else:
+ side_len = in_w
+ padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
+ center = side_len // 2
+ padded_image[
+ center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w
+ ] = image_arr[y : y + h, x : x + w]
+ rgba = Image.fromarray(padded_image).resize((target_res, target_res), Image.LANCZOS)
+ return y + h // 2 - center, x + w // 2 - center, side_len, rgba
+
+def plot_point_cloud(points, arrows):
+ scatter = go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 1],
+ z=points[:, 2],
+ mode='markers',
+ marker=dict(size=2, color='blue'),
+ name='Point Cloud'
+ )
+
+ cone_traces = []
+ for arrow in arrows:
+ origin = arrow['origin']
+ direction = arrow['dir']
+ cone = go.Cone(
+ x=[origin[0]], y=[origin[1]], z=[origin[2]],
+ u=[direction[0]], v=[direction[1]], w=[direction[2]],
+ sizemode='raw', sizeref=5,
+ anchor='tail', colorscale='Reds', showscale=False,
+ name='Arrow'
+ )
+ cone_traces.append(cone)
+
+ fig = go.Figure(data=[scatter] + cone_traces)
+ fig.update_layout(scene=dict(aspectmode='data'), margin=dict(l=0, r=0, t=0, b=0), height=400)
+ return fig
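+
+# Minimal usage sketch (dummy data): each arrow is a dict with 'origin' and 'dir'
+# entries, which is the format plot_point_cloud reads above.
+def _example_point_cloud_figure():
+    pts = np.random.rand(500, 3)
+    arrows = [{'origin': [0.5, 0.5, 0.5], 'dir': [0.0, 0.0, 1.0]}]
+    fig = plot_point_cloud(pts, arrows)
+    return fig  # e.g. call fig.show() in a notebook, or pass it to a gr.Plot component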
diff --git a/src/utils/visualization.py b/src/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e126d7cc1d55754c2472e6a62b97e4ad0f582c
--- /dev/null
+++ b/src/utils/visualization.py
@@ -0,0 +1,318 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import json
+
+from PIL import Image
+from io import BytesIO
+import sys, pathlib, html
+
+def camera_view_dir_y(elev, azim):
+ """Unit vector for camera direction with Y as 'up'."""
+ elev_rad = np.radians(elev)
+ azim_rad = np.radians(azim)
+ dx = np.sin(azim_rad) * np.cos(elev_rad)
+ dy = np.sin(elev_rad)
+ dz = np.cos(azim_rad) * np.cos(elev_rad)
+ return np.array([dx, dy, dz])
+
+def compute_depth(points, elev, azim):
+ """Project points onto the camera's view direction (Y as 'up')."""
+ view_dir = camera_view_dir_y(elev, azim)
+ # depth = p · view_dir
+ depth = points @ view_dir
+ return depth
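+
+# Minimal sketch of how these depths are used below: project dummy points onto the
+# view direction and normalize to [0, 1] for a colormap lookup.
+def _example_depth_coloring():
+    pts = np.random.rand(100, 3)
+    depth = compute_depth(pts, elev=45, azim=225)
+    return (depth - depth.min()) / (depth.max() - depth.min())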
+
+def save_pointcloud_video(points_pred, points_gt, save_path, drag_mask=None, fps=48, point_color='blue', vis_flag=''):
+
+ # Configure the figure
+ fig = plt.figure(figsize=(6, 6))
+ ax = fig.add_subplot(111, projection='3d')
+ ax.set_box_aspect([1, 1, 1])
+
+ if 'objaverse' in vis_flag:
+ x_max, y_max, z_max = 1.5, 1.5, 1.5
+ x_min, y_min, z_min = -1.5, -1.5, -1.5
+ else:
+ x_max, y_max, z_max = 1, 1, 1
+ x_min, y_min, z_min = -1, -1, -1
+
+    # camera view angles (a single view is used for every vis_flag value)
+    elev, azim = 45, 225
+
+ ax.view_init(elev=elev, azim=azim, vertical_axis='y')
+
+ # Plot and save each frame
+ cmap_1 = plt.colormaps.get_cmap('cool')
+ cmap_2 = plt.colormaps.get_cmap('autumn')
+ frames_pred = []
+ frames_gt = []
+
+ if drag_mask is not None and drag_mask.sum() == 0:
+ drag_mask = None
+
+ for label, points in [('pred', points_pred), ('gt', points_gt)]:
+
+ for i in range(points.shape[0]):
+
+ frame_points = points[i]
+ if drag_mask is not None and not (drag_mask == True).all():
+ drag_mask = (drag_mask == 1.0)
+ drag_points = frame_points[drag_mask]
+ frame_points = frame_points[~drag_mask]
+
+ depth_frame_points = compute_depth(frame_points, elev=elev, azim=azim)
+ depth_frame_points_normalized = (depth_frame_points - depth_frame_points.min()) / \
+ (depth_frame_points.max() - depth_frame_points.min())
+ color_frame_points = cmap_1(depth_frame_points_normalized)
+
+ if drag_mask is not None and not (drag_mask == True).all():
+ frame_points_drag = drag_points
+ depth_frame_points_drag = compute_depth(frame_points_drag, elev=elev, azim=azim)
+ depth_frame_points_drag_normalized = (depth_frame_points_drag - depth_frame_points_drag.min()) / \
+ (depth_frame_points_drag.max() - depth_frame_points_drag.min())
+ color_frame_points_drag = cmap_2(np.ones_like(depth_frame_points_drag_normalized) * -10)
+ all_points = np.concatenate([frame_points, frame_points_drag], axis=0)
+ all_color = np.concatenate([color_frame_points, color_frame_points_drag], axis=0)
+ else:
+ all_points, all_color = frame_points, color_frame_points
+
+
+ ax.clear()
+ ax.scatter(all_points[:, 0], all_points[:, 1], all_points[:, 2], c=all_color, s=1, depthshade=False)
+
+ ax.axis('off') # Turn off the axes
+ ax.grid(False) # Hide the grid
+
+ # Set equal aspect ratio
+ ax.set_xlim(x_min, x_max)
+ ax.set_ylim(y_min, y_max)
+ ax.set_zlim(z_min, z_max)
+
+ # Adjust margins for tight layout
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+
+ # Save frame
+ buf = BytesIO()
+ plt.savefig(buf, bbox_inches='tight', pad_inches=0.0, dpi=300)
+ buf.seek(0)
+
+ if label == 'pred':
+ frames_pred.append(Image.open(buf))
+ else:
+ frames_gt.append(Image.open(buf))
+
+ plt.close()
+ frames = []
+ for i in range(len(frames_pred)):
+ frame = np.concatenate([np.array(frames_pred[i]), np.array(frames_gt[i])], axis=1)
+ frames.append(Image.fromarray(frame))
+    # PIL's GIF writer takes a per-frame duration in milliseconds rather than fps
+    frames[0].save(save_path, save_all=True, append_images=frames[1:], duration=int(1000 / fps), loop=0)
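+
+# Minimal usage sketch (random trajectories, hypothetical output path): predicted
+# and ground-truth sequences of shape (T, N, 3) are rendered side by side and
+# written as an animated GIF.
+def _example_save_pointcloud_video():
+    T, N = 16, 256
+    pred = np.random.uniform(-1, 1, size=(T, N, 3))
+    gt = np.random.uniform(-1, 1, size=(T, N, 3))
+    save_pointcloud_video(pred, gt, 'example_tracks.gif', drag_mask=None, fps=12)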
+
+def save_pointcloud_json(points, output_json):
+ """
+ Generate and save a point cloud sequence to a JSON file.
+
+ Parameters:
+ num_frames (int): Number of frames in the sequence.
+ num_points (int): Number of points per frame.
+ output_json (str): Output JSON file path.
+ """
+ sequence = []
+ for frame in range(points.shape[0]):
+ sequence.append({"frame": frame, "points": points[frame].tolist()})
+
+ # Save the sequence to a JSON file
+ with open(output_json, "w") as json_file:
+ json.dump({"sequence": sequence}, json_file)
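+
+# The file written above has the structure
+#   {"sequence": [{"frame": 0, "points": [[x, y, z], ...]}, ...]},
+# which is presumably what the three.js viewer emitted by save_threejs_html loads.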
+
+def save_threejs_html(path1, path2, output_html):
+ html_template = f"""
+    <!-- Three.js Point Cloud Animation: viewer markup for the point-cloud
+         sequences at {path1} and {path2} (template body omitted) -->
+"""
+ with open(output_html, 'w') as file:
+ file.write(html_template)
+
+def vis_pcl_grid(point_clouds, save_path, grid_shape=(2, 2), bounds=[[0, 0, 0], [3, 3, 3]]):
+ """
+ Visualizes multiple 3D point clouds in a grid.
+
+    Args:
+        point_clouds (list of np.ndarray): A list of point clouds, each as a (N, 3) numpy array.
+        save_path (str): Path of the image file to write.
+        grid_shape (tuple): Shape of the grid (rows, cols).
+        bounds (list or None): Axis limits as [[xmin, ymin, zmin], [xmax, ymax, zmax]];
+            axes are autoscaled if None.
+    """
+ rows, cols = grid_shape
+ fig = plt.figure(figsize=(5 * cols, 5 * rows))
+
+ for i, point_cloud in enumerate(point_clouds):
+ if i >= rows * cols:
+ break # Prevent overpopulation of the grid
+
+ ax = fig.add_subplot(rows, cols, i + 1, projection='3d')
+ ax.view_init(elev=0, azim=90)
+ ax.scatter(point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2], s=1)
+
+        # explicit axis limits; autoscale when bounds is None
+        if bounds is not None:
+            ax.set_xlim(bounds[0][0], bounds[1][0])
+            ax.set_ylim(bounds[0][1], bounds[1][1])
+            ax.set_zlim(bounds[0][2], bounds[1][2])
+
+ ax.set_title(f"Point Cloud {i + 1}")
+ ax.set_xlabel("X")
+ ax.set_ylabel("Y")
+ ax.set_zlabel("Z")
+ ax.grid(False)
+
+ plt.tight_layout()
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=300)
+ plt.close()
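+
+# Minimal usage sketch (dummy clouds, hypothetical output path):
+def _example_pcl_grid():
+    clouds = [np.random.rand(200, 3) * 3 for _ in range(4)]
+    vis_pcl_grid(clouds, 'pcl_grid.png', grid_shape=(2, 2), bounds=[[0, 0, 0], [3, 3, 3]])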
+
+def generate_html_from_exts(vis_dir, output_html, exts):
+ gifs = sorted(pathlib.Path(vis_dir).glob(f'*.{exts}'))
+ rows = [
+ "",
+ "",
+ "",
+ " ",
+ " GIF gallery",
+ " ",
+ "",
+ "",
+ ]
+
+    # 4) one <figure> per gif with caption
+ for gif in gifs:
+ name = gif.name # full file name (incl. .gif)
+ alt = html.escape(gif.stem) # alt text sans extension
+ rows.append(
+ f"
per gif with caption
+ for i, gif in enumerate(gif_paths):
+ name = gif # full file name (incl. .gif)
+ alt = html.escape(gif) # alt text sans extension
+ rows.append(
+ f"