init

Files changed:
 - .gitattributes +1 -0
 - .gitignore +142 -0
 - LICENSE +201 -0
 - app.py +372 -0
 - dreamo/dreamo_pipeline.py +466 -0
 - dreamo/transformer.py +187 -0
 - dreamo/utils.py +222 -0
 - example_inputs/cat.png +3 -0
 - example_inputs/dog1.png +3 -0
 - example_inputs/dog2.png +3 -0
 - example_inputs/dress.png +3 -0
 - example_inputs/hinton.jpeg +3 -0
 - example_inputs/man1.png +3 -0
 - example_inputs/man2.jpeg +3 -0
 - example_inputs/mickey.png +3 -0
 - example_inputs/mountain.png +3 -0
 - example_inputs/perfume.png +3 -0
 - example_inputs/shirt.png +3 -0
 - example_inputs/skirt.jpeg +3 -0
 - example_inputs/toy1.png +3 -0
 - example_inputs/woman1.png +3 -0
 - example_inputs/woman2.png +3 -0
 - example_inputs/woman3.png +3 -0
 - example_inputs/woman4.jpeg +3 -0
 - models/.gitkeep +0 -0
 - pyproject.toml +29 -0
 - requirements.txt +12 -0
 - tools/BEN2.py +1359 -0
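
Before the raw hunks, a condensed sketch of the reference-image preprocessing that the app.py diff below implements: "id" references are cropped and aligned with facexlib and have their background whitened via face parsing, other non-"style" references have their background removed with BEN2, and every reference is resized to a target pixel area. The helper name preprocess_reference and the ref_res default are illustrative assumptions; the individual calls are taken verbatim from the generate_image() hunk and assume the Space's own modules (dreamo, tools.BEN2) and a CUDA device are available.

    import numpy as np
    from PIL import Image

    from dreamo.utils import resize_numpy_image_area
    from app import generator  # module-level Generator() instance defined in app.py


    def preprocess_reference(ref_image: np.ndarray, ref_task: str, ref_res: int = 512) -> np.ndarray:
        # Mirror of the branch inside generate_image() in app.py:
        # "id" references get face crop/align plus background whitening,
        # everything except "style" gets BEN2 background removal,
        # and all references are resized to roughly ref_res * ref_res pixels.
        # The ref_res default of 512 is an assumption; the app passes it from the UI.
        if ref_task == "id":
            ref_image = generator.get_align_face(ref_image)
        elif ref_task != "style":
            ref_image = generator.bg_rm_model.inference(Image.fromarray(ref_image))
        return resize_numpy_image_area(np.array(ref_image), ref_res * ref_res)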
 
    	
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+example_inputs/* filter=lfs diff=lfs merge=lfs -text
    	
.gitignore ADDED
@@ -0,0 +1,142 @@
+datasets/*
+experiments/*
+results/*
+tb_logger/*
+wandb/*
+tmp/*
+weights/*
+inputs/*
+
+*.DS_Store
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+.idea/
    	
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,372 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 13 | 
         
            +
            # limitations under the License.
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            import spaces
         
     | 
| 16 | 
         
            +
            import argparse
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            import os
         
     | 
| 19 | 
         
            +
            import cv2
         
     | 
| 20 | 
         
            +
            import gradio as gr
         
     | 
| 21 | 
         
            +
            import numpy as np
         
     | 
| 22 | 
         
            +
            import torch
         
     | 
| 23 | 
         
            +
            from facexlib.utils.face_restoration_helper import FaceRestoreHelper
         
     | 
| 24 | 
         
            +
            import huggingface_hub
         
     | 
| 25 | 
         
            +
            from huggingface_hub import hf_hub_download
         
     | 
| 26 | 
         
            +
            from PIL import Image
         
     | 
| 27 | 
         
            +
            from torchvision.transforms.functional import normalize
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            from dreamo.dreamo_pipeline import DreamOPipeline
         
     | 
| 30 | 
         
            +
            from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img
         
     | 
| 31 | 
         
            +
            from tools import BEN2
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            parser = argparse.ArgumentParser()
         
     | 
| 34 | 
         
            +
            parser.add_argument('--port', type=int, default=8080)
         
     | 
| 35 | 
         
            +
            args = parser.parse_args()
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            huggingface_hub.login(os.getenv('HF_TOKEN'))
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
            class Generator:
         
     | 
| 41 | 
         
            +
                def __init__(self):
         
     | 
| 42 | 
         
            +
                    device = torch.device('cuda')
         
     | 
| 43 | 
         
            +
                    # preprocessing models
         
     | 
| 44 | 
         
            +
                    # background remove model: BEN2
         
     | 
| 45 | 
         
            +
                    self.bg_rm_model = BEN2.BEN_Base().to(device).eval()
         
     | 
| 46 | 
         
            +
                    hf_hub_download(repo_id='PramaLLC/BEN2', filename='BEN2_Base.pth', local_dir='models')
         
     | 
| 47 | 
         
            +
                    self.bg_rm_model.loadcheckpoints('models/BEN2_Base.pth')
         
     | 
| 48 | 
         
            +
                    # face crop and align tool: facexlib
         
     | 
| 49 | 
         
            +
                    self.face_helper = FaceRestoreHelper(
         
     | 
| 50 | 
         
            +
                        upscale_factor=1,
         
     | 
| 51 | 
         
            +
                        face_size=512,
         
     | 
| 52 | 
         
            +
                        crop_ratio=(1, 1),
         
            det_model='retinaface_resnet50',
            save_ext='png',
            device=device,
        )

        # load dreamo
        model_root = 'black-forest-labs/FLUX.1-dev'
        dreamo_pipeline = DreamOPipeline.from_pretrained(model_root, torch_dtype=torch.bfloat16)
        dreamo_pipeline.load_dreamo_model(device, use_turbo=True)
        self.dreamo_pipeline = dreamo_pipeline.to(device)

    @torch.no_grad()
    def get_align_face(self, img):
        # the face preprocessing code is same as PuLID
        self.face_helper.clean_all()
        image_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        self.face_helper.read_image(image_bgr)
        self.face_helper.get_face_landmarks_5(only_center_face=True)
        self.face_helper.align_warp_face()
        if len(self.face_helper.cropped_faces) == 0:
            return None
        align_face = self.face_helper.cropped_faces[0]

        input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
        input = input.to(torch.device("cuda"))
        parsing_out = self.face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
        parsing_out = parsing_out.argmax(dim=1, keepdim=True)
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        bg = sum(parsing_out == i for i in bg_label).bool()
        white_image = torch.ones_like(input)
        # only keep the face features
        face_features_image = torch.where(bg, white_image, input)
        face_features_image = tensor2img(face_features_image, rgb2bgr=False)

        return face_features_image


generator = Generator()


@spaces.GPU
@torch.inference_mode()
def generate_image(
    ref_image1,
    ref_image2,
    ref_task1,
    ref_task2,
    prompt,
    width,
    height,
    ref_res,
    num_steps,
    guidance,
    seed,
    true_cfg,
    cfg_start_step,
    cfg_end_step,
    neg_prompt,
    neg_guidance,
    first_step_guidance,
):
    print(prompt)
    ref_conds = []
    debug_images = []

    ref_images = [ref_image1, ref_image2]
    ref_tasks = [ref_task1, ref_task2]

    for idx, (ref_image, ref_task) in enumerate(zip(ref_images, ref_tasks)):
        if ref_image is not None:
            if ref_task == "id":
                ref_image = generator.get_align_face(ref_image)
            elif ref_task != "style":
                ref_image = generator.bg_rm_model.inference(Image.fromarray(ref_image))
            ref_image = resize_numpy_image_area(np.array(ref_image), ref_res * ref_res)
            debug_images.append(ref_image)
            ref_image = img2tensor(ref_image, bgr2rgb=False).unsqueeze(0) / 255.0
            ref_image = 2 * ref_image - 1.0
            ref_conds.append(
                {
                    'img': ref_image,
                    'task': ref_task,
                    'idx': idx + 1,
                }
            )

    seed = int(seed)
    if seed == -1:
        seed = torch.Generator(device="cpu").seed()

    image = generator.dreamo_pipeline(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_steps,
        guidance_scale=guidance,
        ref_conds=ref_conds,
        generator=torch.Generator(device="cpu").manual_seed(seed),
        true_cfg_scale=true_cfg,
        true_cfg_start_step=cfg_start_step,
        true_cfg_end_step=cfg_end_step,
        negative_prompt=neg_prompt,
        neg_guidance_scale=neg_guidance,
        first_step_guidance_scale=first_step_guidance if first_step_guidance > 0 else guidance,
    ).images[0]

    return image, debug_images, seed


_HEADER_ = '''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
    <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">DreamO</h1>
    <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2504.16915' target='_blank'>DreamO: A Unified Framework for Image Customization</a> | Codes: <a href='https://github.com/bytedance/DreamO' target='_blank'>GitHub</a></p>
</div>

❗️❗️❗️**User Guide:**
- The most important thing to do first is to try the examples provided below the demo, which will help you better understand the capabilities of the DreamO model and the types of tasks it currently supports
- For each input, please select the appropriate task type. For general objects, characters, or clothing, choose IP — we will remove the background from the input image. If you select ID, we will extract the face region from the input image (similar to PuLID). If you select Style, the background will be preserved, and you must prepend the prompt with the instruction: 'generate a same style image.' to activate the style task.
- To accelerate inference, we adopt FLUX-turbo LoRA, which reduces the sampling steps from 25 to 12 compared to FLUX-dev. Additionally, we distill a CFG LoRA, achieving nearly a twofold reduction in steps by eliminating the need for true CFG

'''  # noqa E501

_CITE_ = r"""
If DreamO is helpful, please help to ⭐ the <a href='https://github.com/bytedance/DreamO' target='_blank'> Github Repo</a>. Thanks!
---

📧 **Contact**
If you have any questions or feedbacks, feel free to open a discussion or contact <b>wuyanze123@gmail.com</b> and <b>eechongm@gmail.com</b>
"""  # noqa E501


def create_demo():

    with gr.Blocks() as demo:
        gr.Markdown(_HEADER_)

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    ref_image1 = gr.Image(label="ref image 1", type="numpy", height=256)
                    ref_image2 = gr.Image(label="ref image 2", type="numpy", height=256)
                with gr.Row():
                    ref_task1 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="task for ref image 1")
                    ref_task2 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="task for ref image 2")
                prompt = gr.Textbox(label="Prompt", value="a person playing guitar in the street")
                width = gr.Slider(768, 1024, 1024, step=16, label="Width")
                height = gr.Slider(768, 1024, 1024, step=16, label="Height")
                num_steps = gr.Slider(8, 30, 12, step=1, label="Number of steps")
                guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance")
                seed = gr.Textbox(-1, label="Seed (-1 for random)")
                with gr.Accordion("Advanced Options", open=False, visible=False):
                    ref_res = gr.Slider(512, 1024, 512, step=16, label="resolution for ref image")
                    neg_prompt = gr.Textbox(label="Neg Prompt", value="")
                    neg_guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Neg Guidance")
                    true_cfg = gr.Slider(1, 5, 1, step=0.1, label="true cfg")
                    cfg_start_step = gr.Slider(0, 30, 0, step=1, label="cfg start step")
                    cfg_end_step = gr.Slider(0, 30, 0, step=1, label="cfg end step")
                    first_step_guidance = gr.Slider(0, 10, 0, step=0.1, label="first step guidance")
                generate_btn = gr.Button("Generate")
                gr.Markdown(_CITE_)

            with gr.Column():
                output_image = gr.Image(label="Generated Image", format='png')
                debug_image = gr.Gallery(
                    label="Preprocessing output (including possible face crop and background remove)",
                    elem_id="gallery",
                )
                seed_output = gr.Textbox(label="Used Seed")

        with gr.Row(), gr.Column():
            gr.Markdown("## Examples")
            example_inps = [
                [
                    'example_inputs/woman1.png',
                    'ip',
                    'profile shot dark photo of a 25-year-old female with smoke escaping from her mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome',  # noqa E501
                    9180879731249039735,
                ],
                [
                    'example_inputs/man1.png',
                    'ip',
                    'a man sitting on the cloud, playing guitar',
                    1206523688721442817,
                ],
                [
                    'example_inputs/toy1.png',
                    'ip',
                    'a purple toy holding a sign saying "DreamO", on the mountain',
                    1563188099017016129,
                ],
                [
                    'example_inputs/perfume.png',
                    'ip',
                    'a perfume under spotlight',
                    116150031980664704,
                ],
            ]
            gr.Examples(examples=example_inps, inputs=[ref_image1, ref_task1, prompt, seed], label='IP task', cache_examples='lazy')

            example_inps = [
                [
                    'example_inputs/hinton.jpeg',
                    None,
                    'id',
                    'ip',
                    'portrait, Chibi',
                    5443415087540486371,
                ],
            ]
            gr.Examples(
                examples=example_inps,
                inputs=[ref_image1, ref_task1, prompt, seed],
                label='ID task (similar to PuLID, will only refer to the face)',
                cache_examples='lazy',
            )

            example_inps = [
                [
                    'example_inputs/mickey.png',
                    'style',
                    'generate a same style image. A rooster wearing overalls.',
                    6245580464677124951,
                ],
                [
                    'example_inputs/mountain.png',
                    'style',
                    'generate a same style image. A pavilion by the river, and the distant mountains are endless',
                    5248066378927500767,
                ],
            ]
            gr.Examples(examples=example_inps, inputs=[ref_image1, ref_task1, prompt, seed], label='Style task', cache_examples='lazy')

            example_inps = [
                [
                    'example_inputs/shirt.png',
                    'example_inputs/skirt.jpeg',
                    'ip',
                    'ip',
                    'A girl is wearing a short-sleeved shirt and a short skirt on the beach.',
                    9514069256241143615,
                ],
                [
                    'example_inputs/woman2.png',
                    'example_inputs/dress.png',
                    'id',
                    'ip',
                    'the woman wearing a dress, In the banquet hall',
                    7698454872441022867,
                ],
            ]
            gr.Examples(
                examples=example_inps,
                inputs=[ref_image1, ref_image2, ref_task1, ref_task2, prompt, seed],
                label='Try-On task',
                cache_examples='lazy',
            )

            example_inps = [
                [
                    'example_inputs/dog1.png',
                    'example_inputs/dog2.png',
                    'ip',
                    'ip',
                    'two dogs in the jungle',
                    3356402871128791851,
                ],
                [
                    'example_inputs/woman3.png',
                    'example_inputs/cat.png',
                    'ip',
                    'ip',
                    'A girl rides a giant cat, walking in the noisy modern city. High definition, realistic, non-cartoonish. Excellent photography work, 8k high definition.',  # noqa E501
                    11980469406460273604,
                ],
                [
                    'example_inputs/man2.jpeg',
                    'example_inputs/woman4.jpeg',
                    'ip',
                    'ip',
                    'a man is dancing with a woman in the room',
                    8303780338601106219,
                ],
            ]
            gr.Examples(
                examples=example_inps,
                inputs=[ref_image1, ref_image2, ref_task1, ref_task2, prompt, seed],
                label='Multi IP',
                cache_examples='lazy',
            )

        generate_btn.click(
            fn=generate_image,
            inputs=[
                ref_image1,
                ref_image2,
                ref_task1,
                ref_task2,
                prompt,
                width,
                height,
                ref_res,
                num_steps,
                guidance,
                seed,
                true_cfg,
                cfg_start_step,
                cfg_end_step,
                neg_prompt,
                neg_guidance,
                first_step_guidance,
            ],
            outputs=[output_image, debug_image, seed_output],
        )

    return demo


if __name__ == '__main__':
    demo = create_demo()
    demo.queue().launch(server_name='0.0.0.0', server_port=args.port)
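Editor's note (not part of the commit): the sketch below shows one way the `generate_image` function from app.py above could be driven without the Gradio UI, using the same default values that the demo's sliders expose (12 steps, guidance 3.5, 1024x1024 output, 512 reference resolution). It assumes app.py is importable from the Space's working directory and that the example image exists locally; treat it as a hedged usage sketch, not an official entry point.

# Hypothetical usage sketch, mirroring the Gradio defaults shown in create_demo() above.
import numpy as np
from PIL import Image

from app import generate_image  # assumption: app.py is importable as a module

# Load a reference image as an RGB numpy array (same format gr.Image(type="numpy") produces).
ref1 = np.array(Image.open('example_inputs/woman1.png').convert('RGB'))

image, debug_images, used_seed = generate_image(
    ref_image1=ref1,
    ref_image2=None,             # unused slots are simply skipped by the loop in generate_image
    ref_task1='ip',              # 'ip' removes background, 'id' crops/aligns the face, 'style' keeps the image as-is
    ref_task2='ip',
    prompt='a person playing guitar in the street',
    width=1024,
    height=1024,
    ref_res=512,
    num_steps=12,                # turbo-LoRA default used by the demo
    guidance=3.5,
    seed=-1,                     # -1 draws a random seed; the seed actually used is returned
    true_cfg=1,
    cfg_start_step=0,
    cfg_end_step=0,
    neg_prompt='',
    neg_guidance=3.5,
    first_step_guidance=0,       # 0 falls back to `guidance` inside generate_image
)
image.save('dreamo_output.png')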
    	
dreamo/dreamo_pipeline.py
ADDED
@@ -0,0 +1,466 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Union

import diffusers
import numpy as np
import torch
import torch.nn as nn
from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from einops import repeat
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from dreamo.transformer import flux_transformer_forward
from dreamo.utils import convert_flux_lora_to_diffusers

diffusers.models.transformers.transformer_flux.FluxTransformer2DModel.forward = flux_transformer_forward


def get_task_embedding_idx(task):
    return 0


class DreamOPipeline(FluxPipeline):
    def __init__(self, scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer):
        super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer)
        self.t5_embedding = nn.Embedding(10, 4096)
        self.task_embedding = nn.Embedding(2, 3072)
        self.idx_embedding = nn.Embedding(10, 3072)

    def load_dreamo_model(self, device, use_turbo=True):
        hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo.safetensors', local_dir='models')
        hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_cfg_distill.safetensors', local_dir='models')
        dreamo_lora = load_file('models/dreamo.safetensors')
        cfg_distill_lora = load_file('models/dreamo_cfg_distill.safetensors')
        self.t5_embedding.weight.data = dreamo_lora.pop('dreamo_t5_embedding.weight')[-10:]
        self.task_embedding.weight.data = dreamo_lora.pop('dreamo_task_embedding.weight')
        self.idx_embedding.weight.data = dreamo_lora.pop('dreamo_idx_embedding.weight')
        self._prepare_t5()

        dreamo_diffuser_lora = convert_flux_lora_to_diffusers(dreamo_lora)
        cfg_diffuser_lora = convert_flux_lora_to_diffusers(cfg_distill_lora)
        adapter_names = ['dreamo']
        adapter_weights = [1]
        self.load_lora_weights(dreamo_diffuser_lora, adapter_name='dreamo')
        if cfg_diffuser_lora is not None:
            self.load_lora_weights(cfg_diffuser_lora, adapter_name='cfg')
            adapter_names.append('cfg')
            adapter_weights.append(1)
        if use_turbo:
            self.load_lora_weights(
                hf_hub_download(
                    "alimama-creative/FLUX.1-Turbo-Alpha", "diffusion_pytorch_model.safetensors", local_dir='models'
                ),
                adapter_name='turbo',
            )
            adapter_names.append('turbo')
            adapter_weights.append(1)

        self.fuse_lora(adapter_names=adapter_names, adapter_weights=adapter_weights, lora_scale=1)

        self.t5_embedding = self.t5_embedding.to(device)
        self.task_embedding = self.task_embedding.to(device)
        self.idx_embedding = self.idx_embedding.to(device)

    def _prepare_t5(self):
        self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
        num_new_token = 10
        new_token_list = [f"[ref#{i}]" for i in range(1, 10)] + ["[res]"]
        self.tokenizer_2.add_tokens(new_token_list, special_tokens=False)
        self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
        input_embedding = self.text_encoder_2.get_input_embeddings().weight.data
        input_embedding[-num_new_token:] = self.t5_embedding.weight.data

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + start_height
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + start_width

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _prepare_style_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + start_height
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + start_width

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt: Union[str, List[str]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        true_cfg_scale: float = 1.0,
        true_cfg_start_step: int = 1,
        true_cfg_end_step: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        neg_guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        ref_conds=None,
        first_step_guidance_scale=3.5,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                will be used instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
                not greater than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
            true_cfg_scale (`float`, *optional*, defaults to 1.0):
                When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         
     | 
| 199 | 
         
            +
                            provided, text embeddings will be generated from `prompt` input argument.
         
     | 
| 200 | 
         
            +
                        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
         
     | 
| 201 | 
         
            +
                            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
         
     | 
| 202 | 
         
            +
                            If not provided, pooled text embeddings will be generated from `prompt` input argument.
         
     | 
| 203 | 
         
            +
                        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
         
     | 
| 204 | 
         
            +
                            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
         
     | 
| 205 | 
         
            +
                            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
         
     | 
| 206 | 
         
            +
                            argument.
         
     | 
| 207 | 
         
            +
                        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
         
     | 
| 208 | 
         
            +
                            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
         
     | 
| 209 | 
         
            +
                            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
         
     | 
| 210 | 
         
            +
                            input argument.
         
     | 
| 211 | 
         
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         
     | 
| 212 | 
         
            +
                            The output format of the generate image. Choose between
         
     | 
| 213 | 
         
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         
     | 
| 214 | 
         
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         
     | 
| 215 | 
         
            +
                            Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
         
     | 
| 216 | 
         
            +
                        joint_attention_kwargs (`dict`, *optional*):
         
     | 
| 217 | 
         
            +
                            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
         
     | 
| 218 | 
         
            +
                            `self.processor` in
         
     | 
| 219 | 
         
            +
                            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
         
     | 
| 220 | 
         
            +
                        callback_on_step_end (`Callable`, *optional*):
         
     | 
| 221 | 
         
            +
                            A function that calls at the end of each denoising steps during the inference. The function is called
         
     | 
| 222 | 
         
            +
                            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
         
     | 
| 223 | 
         
            +
                            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
         
     | 
| 224 | 
         
            +
                            `callback_on_step_end_tensor_inputs`.
         
     | 
| 225 | 
         
            +
                        callback_on_step_end_tensor_inputs (`List`, *optional*):
         
     | 
| 226 | 
         
            +
                            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
         
     | 
| 227 | 
         
            +
                            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
         
     | 
| 228 | 
         
            +
                            `._callback_tensor_inputs` attribute of your pipeline class.
         
     | 
| 229 | 
         
            +
                        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
         
     | 
| 230 | 
         
            +
             
     | 
| 231 | 
         
            +
                    Examples:
         
     | 
| 232 | 
         
            +
             
     | 
| 233 | 
         
            +
                    Returns:
         
     | 
| 234 | 
         
            +
                        [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
         
     | 
| 235 | 
         
            +
                        is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
         
     | 
| 236 | 
         
            +
                        images.
         
     | 
| 237 | 
         
            +
                    """
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
        )
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        if do_true_cfg:
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                _,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 4.1 concat ref tokens to latent
        origin_img_len = latents.shape[1]
        embeddings = repeat(self.task_embedding.weight[1], "c -> n l c", n=batch_size, l=origin_img_len)
        ref_latents = []
        ref_latent_image_idss = []
        start_height = height // 16
        start_width = width // 16
        for ref_cond in ref_conds:
            img = ref_cond['img']  # [b, 3, h, w], range [-1, 1]
            task = ref_cond['task']
            idx = ref_cond['idx']

            # encode ref with VAE
            img = img.to(latents)
            ref_latent = self.vae.encode(img).latent_dist.sample()
            ref_latent = (ref_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            cur_height = ref_latent.shape[2]
            cur_width = ref_latent.shape[3]
            ref_latent = self._pack_latents(ref_latent, batch_size, num_channels_latents, cur_height, cur_width)
            ref_latent_image_ids = self._prepare_latent_image_ids(
                batch_size, cur_height, cur_width, device, prompt_embeds.dtype, start_height, start_width
            )
            start_height += cur_height // 2
            start_width += cur_width // 2

            # prepare task_idx_embedding
            task_idx = get_task_embedding_idx(task)
            cur_task_embedding = repeat(
                self.task_embedding.weight[task_idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1]
            )
            cur_idx_embedding = repeat(
                self.idx_embedding.weight[idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1]
            )
            cur_embedding = cur_task_embedding + cur_idx_embedding

            # concat ref to latent
            embeddings = torch.cat([embeddings, cur_embedding], dim=1)
            ref_latents.append(ref_latent)
            ref_latent_image_idss.append(ref_latent_image_ids)

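        # Illustrative note (not part of the original diff): with the usual Flux packing
        # (VAE stride 8, then 2x2 patching), a 1024x1024 target gives a 128x128 latent and
        # 64*64 = 4096 packed tokens, so origin_img_len == 4096; a 512x512 reference adds
        # 32*32 = 1024 tokens. Each reference gets positional ids offset by
        # (start_height, start_width), which begins at (height // 16, width // 16) and is
        # advanced by (cur_height // 2, cur_width // 2) per reference, so reference tokens
        # do not overlap the target grid or each other in RoPE coordinates.
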
        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

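        # Illustrative note (not part of the original diff): assuming `calculate_shift` is the
        # standard Flux helper from diffusers, `mu` is a linear interpolation between base_shift
        # and max_shift as image_seq_len moves from base_image_seq_len to max_image_seq_len.
        # image_seq_len here counts only the target tokens (latents.shape[1]), not the appended
        # reference tokens, so the timestep shift matches plain text-to-image sampling.
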
        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None
        neg_guidance = torch.full([1], neg_guidance_scale, device=device, dtype=torch.float32)
        neg_guidance = neg_guidance.expand(latents.shape[0])
        first_step_guidance = torch.full([1], first_step_guidance_scale, device=device, dtype=torch.float32)

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=torch.cat((latents, *ref_latents), dim=1),
                    timestep=timestep / 1000,
                    guidance=guidance if i > 0 else first_step_guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=torch.cat((latent_image_ids, *ref_latent_image_idss), dim=1),
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    embeddings=embeddings,
                )[0][:, :origin_img_len]

                if do_true_cfg and i >= true_cfg_start_step and i < true_cfg_end_step:
                    neg_noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=neg_guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
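A minimal usage sketch for the call method above (illustrative, not part of this commit): the class name `DreamOPipeline`, the checkpoint id, and the task string `'ip'` are assumptions, since the class definition and `get_task_embedding_idx` live elsewhere in this file and in `app.py`. Each `ref_conds` entry matches what the loop above reads: an `img` tensor of shape `[B, 3, H, W]` in `[-1, 1]`, a `task` name, and an `idx`.

import torch

from dreamo.dreamo_pipeline import DreamOPipeline  # class name assumed for this sketch

# Hypothetical wiring: checkpoint id and device are placeholders, not taken from this diff.
pipe = DreamOPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Placeholder reference image; in practice this is a preprocessed reference in [-1, 1].
ref_img = torch.rand(1, 3, 512, 512) * 2 - 1  # [B, 3, H, W]
ref_conds = [{"img": ref_img, "task": "ip", "idx": 1}]  # 'ip' and idx=1 are illustrative values

result = pipe(
    prompt="a toy on a beach",
    ref_conds=ref_conds,
    height=1024,
    width=1024,
    guidance_scale=3.5,
)
image = result.images[0]  # FluxPipelineOutput.images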
         
dreamo/transformer.py
ADDED
@@ -0,0 +1,187 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def flux_transformer_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,
    embeddings: torch.Tensor = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    """
    The [`FluxTransformer2DModel`] forward method.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
            Input `hidden_states`.
        encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
            from the embeddings of input conditions.
        timestep (`torch.LongTensor`):
            Used to indicate denoising step.
        controlnet_block_samples (`list` of `torch.Tensor`):
            A list of tensors that if specified are added to the residuals of transformer blocks.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

    hidden_states = self.x_embedder(hidden_states)
    # add task and idx embedding
    if embeddings is not None:
        hidden_states = hidden_states + embeddings

    timestep = timestep.to(hidden_states.dtype) * 1000
    guidance = guidance.to(hidden_states.dtype) * 1000 if guidance is not None else None

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    if txt_ids.ndim == 3:
        # logger.warning(
        #     "Passing `txt_ids` 3d torch.Tensor is deprecated."
        #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
        # )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        # logger.warning(
        #     "Passing `img_ids` 3d torch.Tensor is deprecated."
        #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
        # )
        img_ids = img_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)

    for index_block, block in enumerate(self.transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
            )

        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # controlnet residual
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
            interval_control = int(np.ceil(interval_control))
            # For Xlabs ControlNet.
            if controlnet_blocks_repeat:
                hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
            else:
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    for index_block, block in enumerate(self.single_transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                temb,
                image_rotary_emb,
            )

        else:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # controlnet residual
        if controlnet_single_block_samples is not None:
            interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
            interval_control = int(np.ceil(interval_control))
            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                + controlnet_single_block_samples[index_block // interval_control]
            )

    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
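The forward above mirrors diffusers' `FluxTransformer2DModel.forward`, with one addition: an optional `embeddings` tensor added to the image tokens right after `x_embedder`, which is how the pipeline's task/idx embeddings reach the transformer. A sketch of one way to attach it (the actual binding happens in `dreamo/dreamo_pipeline.py`, which is not shown in this hunk, so treat the wiring below as an assumption):

import types

import torch
from diffusers import FluxTransformer2DModel

from dreamo.transformer import flux_transformer_forward

# Hypothetical wiring: load the stock Flux transformer and rebind its forward so callers
# can pass the extra `embeddings=` keyword used by the DreamO pipeline.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.forward = types.MethodType(flux_transformer_forward, transformer)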
    	
dreamo/utils.py
ADDED
@@ -0,0 +1,222 @@
| 1 | 
         
            +
            # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 13 | 
         
            +
            # limitations under the License.
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            import math
         
     | 
| 16 | 
         
            +
            import re
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            import cv2
         
     | 
| 19 | 
         
            +
            import numpy as np
         
     | 
| 20 | 
         
            +
            import torch
         
     | 
| 21 | 
         
            +
            from torchvision.utils import make_grid
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            # from basicsr
         
     | 
| 25 | 
         
            +
            def img2tensor(imgs, bgr2rgb=True, float32=True):
         
     | 
| 26 | 
         
            +
                """Numpy array to tensor.
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
                Args:
         
     | 
| 29 | 
         
            +
                    imgs (list[ndarray] | ndarray): Input images.
         
     | 
| 30 | 
         
            +
                    bgr2rgb (bool): Whether to change bgr to rgb.
         
     | 
| 31 | 
         
            +
                    float32 (bool): Whether to change to float32.
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
                Returns:
         
     | 
| 34 | 
         
            +
                    list[tensor] | tensor: Tensor images. If returned results only have
         
     | 
| 35 | 
         
            +
                        one element, just return tensor.
         
     | 
| 36 | 
         
            +
                """
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
                def _totensor(img, bgr2rgb, float32):
         
     | 
| 39 | 
         
            +
                    if img.shape[2] == 3 and bgr2rgb:
         
     | 
| 40 | 
         
            +
                        if img.dtype == 'float64':
         
     | 
| 41 | 
         
            +
                            img = img.astype('float32')
         
     | 
| 42 | 
         
            +
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         
     | 
| 43 | 
         
            +
                    img = torch.from_numpy(img.transpose(2, 0, 1))
         
     | 
| 44 | 
         
            +
                    if float32:
         
     | 
| 45 | 
         
            +
                        img = img.float()
         
     | 
| 46 | 
         
            +
                    return img
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                if isinstance(imgs, list):
         
     | 
| 49 | 
         
            +
                    return [_totensor(img, bgr2rgb, float32) for img in imgs]
         
     | 
| 50 | 
         
            +
                return _totensor(imgs, bgr2rgb, float32)
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
            def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
         
     | 
| 54 | 
         
            +
                """Convert torch Tensors into image numpy arrays.
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                After clamping to [min, max], values will be normalized to [0, 1].
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                Args:
         
     | 
| 59 | 
         
            +
                    tensor (Tensor or list[Tensor]): Accept shapes:
         
     | 
| 60 | 
         
            +
                        1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
         
     | 
| 61 | 
         
            +
                        2) 3D Tensor of shape (3/1 x H x W);
         
     | 
| 62 | 
         
            +
                        3) 2D Tensor of shape (H x W).
         
     | 
| 63 | 
         
            +
                        Tensor channel should be in RGB order.
         
     | 
| 64 | 
         
            +
                    rgb2bgr (bool): Whether to change rgb to bgr.
         
     | 
| 65 | 
         
            +
                    out_type (numpy type): output types. If ``np.uint8``, transform outputs
         
     | 
| 66 | 
         
            +
                        to uint8 type with range [0, 255]; otherwise, float type with
         
     | 
| 67 | 
         
            +
                        range [0, 1]. Default: ``np.uint8``.
         
     | 
| 68 | 
         
            +
                    min_max (tuple[int]): min and max values for clamp.
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                Returns:
         
     | 
| 71 | 
         
            +
                    (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
         
     | 
| 72 | 
         
            +
                    shape (H x W). The channel order is BGR.
         
     | 
| 73 | 
         
            +
                """
         
     | 
| 74 | 
         
            +
                if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
         
     | 
| 75 | 
         
            +
                    raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                if torch.is_tensor(tensor):
         
     | 
| 78 | 
         
            +
                    tensor = [tensor]
         
     | 
| 79 | 
         
            +
                result = []
         
     | 
| 80 | 
         
            +
                for _tensor in tensor:
         
     | 
| 81 | 
         
            +
                    _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
         
     | 
| 82 | 
         
            +
                    _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                    n_dim = _tensor.dim()
         
     | 
| 85 | 
         
            +
                    if n_dim == 4:
         
     | 
| 86 | 
         
            +
                        img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
         
     | 
| 87 | 
         
            +
                        img_np = img_np.transpose(1, 2, 0)
         
     | 
| 88 | 
         
            +
                        if rgb2bgr:
         
     | 
| 89 | 
         
            +
                            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
         
     | 
| 90 | 
         
            +
                    elif n_dim == 3:
         
     | 
| 91 | 
         
            +
                        img_np = _tensor.numpy()
         
     | 
| 92 | 
         
            +
                        img_np = img_np.transpose(1, 2, 0)
         
     | 
| 93 | 
         
            +
                        if img_np.shape[2] == 1:  # gray image
         
     | 
| 94 | 
         
            +
                            img_np = np.squeeze(img_np, axis=2)
         
     | 
| 95 | 
         
            +
                        else:
         
     | 
| 96 | 
         
            +
                            if rgb2bgr:
         
     | 
| 97 | 
         
            +
                                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
         
     | 
| 98 | 
         
            +
                    elif n_dim == 2:
         
     | 
| 99 | 
         
            +
                        img_np = _tensor.numpy()
         
     | 
| 100 | 
         
            +
                    else:
         
     | 
| 101 | 
         
            +
                        raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
         
     | 
| 102 | 
         
            +
                    if out_type == np.uint8:
         
     | 
| 103 | 
         
            +
                        # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
         
     | 
| 104 | 
         
            +
                        img_np = (img_np * 255.0).round()
         
     | 
| 105 | 
         
            +
                    img_np = img_np.astype(out_type)
         
     | 
| 106 | 
         
            +
                    result.append(img_np)
         
     | 
| 107 | 
         
            +
                if len(result) == 1:
         
     | 
| 108 | 
         
            +
                    result = result[0]
         
     | 
| 109 | 
         
            +
                return result
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
            def resize_numpy_image_area(image, area=512 * 512):
         
     | 
| 113 | 
         
            +
                h, w = image.shape[:2]
         
     | 
| 114 | 
         
            +
                k = math.sqrt(area / (h * w))
         
     | 
| 115 | 
         
            +
                h = int(h * k) - (int(h * k) % 16)
         
     | 
| 116 | 
         
            +
                w = int(w * k) - (int(w * k) % 16)
         
     | 
| 117 | 
         
            +
                image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
         
     | 
| 118 | 
         
            +
                return image
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
            # reference: https://github.com/huggingface/diffusers/pull/9295/files
         
     | 
| 122 | 
         
            +
            def convert_flux_lora_to_diffusers(old_state_dict):
         
     | 
| 123 | 
         
            +
                new_state_dict = {}
         
     | 
| 124 | 
         
            +
                orig_keys = list(old_state_dict.keys())
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
                def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         
     | 
| 127 | 
         
            +
                    down_weight = sds_sd.pop(sds_key)
         
     | 
| 128 | 
         
            +
                    up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                    # calculate dims if not provided
         
     | 
| 131 | 
         
            +
                    num_splits = len(ait_keys)
         
     | 
| 132 | 
         
            +
                    if dims is None:
         
     | 
| 133 | 
         
            +
                        dims = [up_weight.shape[0] // num_splits] * num_splits
         
     | 
| 134 | 
         
            +
                    else:
         
     | 
| 135 | 
         
            +
                        assert sum(dims) == up_weight.shape[0]
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                    # make ai-toolkit weight
         
     | 
| 138 | 
         
            +
                    ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
         
     | 
| 139 | 
         
            +
                    ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                    # down_weight is copied to each split
         
     | 
| 142 | 
         
            +
                    ait_sd.update({k: down_weight for k in ait_down_keys})
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                    # up_weight is split to each split
         
     | 
| 145 | 
         
            +
                    ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                for old_key in orig_keys:
         
     | 
| 148 | 
         
            +
                    # Handle double_blocks
         
     | 
| 149 | 
         
            +
                    if 'double_blocks' in old_key:
         
     | 
| 150 | 
         
            +
                        block_num = re.search(r"double_blocks_(\d+)", old_key).group(1)
         
     | 
| 151 | 
         
            +
                        new_key = f"transformer.transformer_blocks.{block_num}"
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                        if "proj_lora1" in old_key:
         
     | 
| 154 | 
         
            +
                            new_key += ".attn.to_out.0"
         
     | 
| 155 | 
         
            +
                        elif "proj_lora2" in old_key:
         
     | 
| 156 | 
         
            +
                            new_key += ".attn.to_add_out"
         
     | 
| 157 | 
         
            +
                        elif "qkv_lora2" in old_key and "up" not in old_key:
         
     | 
| 158 | 
         
            +
                            handle_qkv(
         
     | 
| 159 | 
         
            +
                                old_state_dict,
         
     | 
| 160 | 
         
            +
                                new_state_dict,
         
     | 
| 161 | 
         
            +
                                old_key,
         
     | 
| 162 | 
         
            +
                                [
         
     | 
| 163 | 
         
            +
                                    f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
         
     | 
| 164 | 
         
            +
                                    f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
         
     | 
| 165 | 
         
            +
                                    f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
         
     | 
| 166 | 
         
            +
                                ],
         
     | 
| 167 | 
         
            +
                            )
         
     | 
| 168 | 
         
            +
                            # continue
         
     | 
| 169 | 
         
            +
                        elif "qkv_lora1" in old_key and "up" not in old_key:
         
     | 
| 170 | 
         
            +
                            handle_qkv(
         
     | 
| 171 | 
         
            +
                                old_state_dict,
         
     | 
| 172 | 
         
            +
                                new_state_dict,
         
     | 
| 173 | 
         
            +
                                old_key,
         
     | 
| 174 | 
         
            +
                                [
         
     | 
| 175 | 
         
            +
                                    f"transformer.transformer_blocks.{block_num}.attn.to_q",
         
     | 
| 176 | 
         
            +
                                    f"transformer.transformer_blocks.{block_num}.attn.to_k",
         
     | 
| 177 | 
         
            +
                                    f"transformer.transformer_blocks.{block_num}.attn.to_v",
         
     | 
| 178 | 
         
            +
                                ],
         
     | 
| 179 | 
         
            +
                            )
         
     | 
| 180 | 
         
            +
                            # continue
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                        if "down" in old_key:
         
     | 
| 183 | 
         
            +
                            new_key += ".lora_A.weight"
         
     | 
| 184 | 
         
            +
                        elif "up" in old_key:
         
     | 
| 185 | 
         
            +
                            new_key += ".lora_B.weight"
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                    # Handle single_blocks
         
     | 
| 188 | 
         
            +
                    elif 'single_blocks' in old_key:
         
     | 
| 189 | 
         
            +
                        block_num = re.search(r"single_blocks_(\d+)", old_key).group(1)
         
     | 
| 190 | 
         
            +
                        new_key = f"transformer.single_transformer_blocks.{block_num}"
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
                        if "proj_lora" in old_key:
         
     | 
| 193 | 
         
            +
                            new_key += ".proj_out"
         
     | 
| 194 | 
         
            +
                        elif "qkv_lora" in old_key and "up" not in old_key:
         
     | 
| 195 | 
         
            +
                            handle_qkv(
         
     | 
| 196 | 
         
            +
                                old_state_dict,
         
     | 
| 197 | 
         
            +
                                new_state_dict,
         
     | 
| 198 | 
         
            +
                                old_key,
         
     | 
| 199 | 
         
            +
                                [
         
     | 
| 200 | 
         
            +
                                    f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
         
     | 
| 201 | 
         
            +
                                    f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
         
     | 
| 202 | 
         
            +
                                    f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
         
     | 
| 203 | 
         
            +
                                ],
         
     | 
| 204 | 
         
            +
                            )
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                        if "down" in old_key:
         
     | 
| 207 | 
         
            +
                            new_key += ".lora_A.weight"
         
     | 
| 208 | 
         
            +
                        elif "up" in old_key:
         
     | 
| 209 | 
         
            +
                            new_key += ".lora_B.weight"
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                    else:
         
     | 
| 212 | 
         
            +
                        # Handle other potential key patterns here
         
     | 
| 213 | 
         
            +
                        new_key = old_key
         
     | 
| 214 | 
         
            +
             
     | 
| 215 | 
         
            +
                    # Since we already handle qkv above.
         
     | 
| 216 | 
         
            +
                    if "qkv" not in old_key and 'embedding' not in old_key:
         
     | 
| 217 | 
         
            +
                        new_state_dict[new_key] = old_state_dict.pop(old_key)
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
                # if len(old_state_dict) > 0:
         
     | 
| 220 | 
         
            +
                #     raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
                return new_state_dict
         
     | 
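For orientation, the image helpers in dreamo/utils.py compose naturally for preprocessing and postprocessing. The sketch below is a usage illustration under assumed inputs (a random H x W x 3 uint8 array standing in for a cv2.imread result); it is not part of this commit.

# Usage sketch for dreamo/utils.py helpers (illustrative only, not part of the commit)
import numpy as np
from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img

img = (np.random.rand(720, 1280, 3) * 255).astype(np.uint8)  # stand-in for cv2.imread(...)
img = resize_numpy_image_area(img, area=512 * 512)           # area ~512*512, both sides multiples of 16
t = img2tensor(img, bgr2rgb=True, float32=True) / 255.0      # CHW float tensor in [0, 1]
back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))           # HWC uint8, BGR channel order again
print(img.shape, t.shape, back.shape)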
    	
example_inputs/cat.png
ADDED
Git LFS Details
example_inputs/dog1.png
ADDED
Git LFS Details
example_inputs/dog2.png
ADDED
Git LFS Details
example_inputs/dress.png
ADDED
Git LFS Details
example_inputs/hinton.jpeg
ADDED
Git LFS Details
example_inputs/man1.png
ADDED
Git LFS Details
example_inputs/man2.jpeg
ADDED
Git LFS Details
example_inputs/mickey.png
ADDED
Git LFS Details
example_inputs/mountain.png
ADDED
Git LFS Details
example_inputs/perfume.png
ADDED
Git LFS Details
example_inputs/shirt.png
ADDED
Git LFS Details
example_inputs/skirt.jpeg
ADDED
Git LFS Details
example_inputs/toy1.png
ADDED
Git LFS Details
example_inputs/woman1.png
ADDED
Git LFS Details
example_inputs/woman2.png
ADDED
Git LFS Details
example_inputs/woman3.png
ADDED
Git LFS Details
example_inputs/woman4.jpeg
ADDED
Git LFS Details
    	
        models/.gitkeep
    ADDED
    
File without changes
    	
        pyproject.toml
    ADDED
    
@@ -0,0 +1,29 @@
[tool.ruff]
line-length = 120
exclude = ['tools']
# A list of file patterns to omit from linting, in addition to those specified by exclude.
extend-exclude = ["__pycache__", "*.pyc", "*.egg-info", ".cache"]

select = ["E", "F", "W", "C90", "I", "UP", "B", "C4", "RET", "RUF", "SIM"]


ignore = [
    "UP006",    # UP006: Use list instead of typing.List for type annotations
    "UP007",    # UP007: Use X | Y for type annotations
    "UP009",
    "UP035",
    "UP038",
    "E402",
    "RET504",
    "C901",
    "RUF013",
    "B006",
]

[tool.isort]
profile = "black"

[tool.black]
line-length = 119
skip-string-normalization = 1
exclude = 'tools'
    	
        requirements.txt
    ADDED
    
@@ -0,0 +1,12 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.3.1+cu118
torchvision==0.18.1+cu118

diffusers==0.31.0
transformers==4.45.2
sentencepiece
spaces
huggingface_hub
accelerate==0.32.0
peft
git+https://github.com/ToTheBeginning/facexlib.git
    	
        tools/BEN2.py
    ADDED
    
    | 
         @@ -0,0 +1,1359 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
# Copyright (c) 2025 Prama LLC
# SPDX-License-Identifier: MIT

import math
import os
import random
import subprocess
import tempfile
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from PIL import Image, ImageOps
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torchvision import transforms


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# set_random_seed(9)

torch.set_float32_matmul_precision('highest')


class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
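
# --- Illustrative usage sketch, not part of the original module. ---
# window_partition splits a (B, H, W, C) feature map into non-overlapping
# window_size x window_size tiles, and window_reverse undoes it exactly when
# H and W are multiples of window_size. The sizes below are assumptions chosen
# only for this demonstration.
def _demo_window_roundtrip(window_size=7):
    x = torch.randn(2, 14, 14, 96)                        # (B, H, W, C)
    windows = window_partition(x, window_size)            # (B*nW, 7, 7, 96) = (8, 7, 7, 96)
    y = window_reverse(windows, window_size, H=14, W=14)  # back to (2, 14, 14, 96)
    assert torch.equal(x, y)                              # exact inverse, no values change
    return windows.shape, y.shape
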
class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
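
# --- Illustrative usage sketch, not part of the original module. ---
# WindowAttention works on flattened windows: it takes tokens of shape
# (num_windows*B, window_size*window_size, dim) and returns the same shape.
# dim must be divisible by num_heads. The values here are assumptions chosen
# only for this demonstration.
def _demo_window_attention():
    attn = WindowAttention(dim=96, window_size=to_2tuple(7), num_heads=3)
    tokens = torch.randn(8, 7 * 7, 96)   # 8 windows of 49 tokens each
    out = attn(tokens)                   # -> (8, 49, 96), same shape as the input
    return out.shape
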
class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
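
# --- Illustrative usage sketch, not part of the original module. ---
# A SwinTransformerBlock consumes tokens of shape (B, H*W, C); the spatial size
# is passed through the .H/.W attributes, as BasicLayer does further below.
# With shift_size=0 no attention mask is used, so mask_matrix may be None.
# The sizes are assumptions chosen only for this demonstration.
def _demo_swin_block():
    blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
    blk.H, blk.W = 20, 20                # forward pads 20x20 up to 21x21 internally
    x = torch.randn(2, 20 * 20, 96)      # (B, H*W, C)
    out = blk(x, mask_matrix=None)       # -> (2, 400, 96)
    return out.shape
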
class PatchMerging(nn.Module):
    """ Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
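
# --- Illustrative usage sketch, not part of the original module. ---
# PatchMerging halves the spatial resolution and doubles the channel count:
# (B, H*W, C) -> (B, (H/2)*(W/2), 2C), padding odd H or W first. The numbers
# are assumptions chosen only for this demonstration.
def _demo_patch_merging():
    merge = PatchMerging(dim=96)
    x = torch.randn(2, 56 * 56, 96)      # (B, H*W, C)
    out = merge(x, H=56, W=56)           # -> (2, 28*28, 192)
    return out.shape
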
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels
        depth (int): Depth of this stage.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
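
# --- Illustrative usage sketch, not part of the original module. ---
# A BasicLayer returns both the stage output at the input resolution and the
# downsampled tensor handed to the next stage (the two are identical when
# downsample=None). The configuration is an assumption chosen only for this
# demonstration.
def _demo_basic_layer():
    stage = BasicLayer(dim=96, depth=2, num_heads=3, window_size=7, downsample=PatchMerging)
    x = torch.randn(2, 56 * 56, 96)
    x_out, H, W, x_down, Wh, Ww = stage(x, H=56, W=56)
    # x_out: (2, 3136, 96) at 56x56; x_down: (2, 784, 192) at 28x28
    return x_out.shape, (H, W), x_down.shape, (Wh, Ww)
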
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x
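
# --- Illustrative usage sketch, not part of the original module. ---
# PatchEmbed maps an image batch (B, 3, H, W) to a (B, embed_dim, H/4, W/4)
# feature map via a strided convolution, padding H and W up to multiples of
# patch_size when needed. The input size is an assumption chosen only for this
# demonstration.
def _demo_patch_embed():
    embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    img = torch.randn(2, 3, 224, 224)
    feat = embed(img)                    # -> (2, 96, 56, 56)
    return feat.shape
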
class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention heads of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):

        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed)  # B Wh*Ww C

        outs = [x.contiguous()]
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)
| 609 | 
         
            +
def get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "gelu":
        return F.gelu

    raise RuntimeError(f"activation should be gelu, not {activation}.")


def make_cbr(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())


def make_cbg(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())


def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
    return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)


def resize_as(x, y, interpolation='bilinear'):
    return F.interpolate(x, size=y.shape[-2:], mode=interpolation)


def image2patches(x):
    """b c (hg h) (wg w) -> (hg wg b) c h w"""
    x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
    return x


def patches2image(x):
    """(hg wg b) c h w -> b c (hg h) (wg w)"""
    x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
    return x

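# Illustrative note (added, not in the original source): image2patches splits each
# image into a 2x2 grid of quadrants stacked along the batch dimension, and
# patches2image inverts that split, e.g.
#     x = torch.randn(1, 3, 64, 64)
#     assert image2patches(x).shape == (4, 3, 32, 32)
#     assert patches2image(image2patches(x)).shape == x.shape
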
class PositionEmbeddingSine:
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale
        self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)

    def __call__(self, b, h, w):
        device = self.dim_t.device
        mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
        x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)

        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

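# Shape note (added, not in the original source): calling an instance with
# (b, h, w) returns a sine/cosine positional embedding of shape
# (b, 2 * num_pos_feats, h, w), e.g.
#     pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
#     assert pe(1, 16, 16).shape == (1, 128, 16, 16)
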
class MCLM(nn.Module):
    def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
        super(MCLM, self).__init__()
        self.attention = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
        ])

        self.linear1 = nn.Linear(d_model, d_model * 2)
        self.linear2 = nn.Linear(d_model * 2, d_model)
        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.activation = get_activation_fn('gelu')
        self.pool_ratios = pool_ratios
        self.p_poses = []
        self.g_pos = None
        self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)

    def forward(self, l, g):
        """
        l: 4,c,h,w
        g: 1,c,h,w
        """
        self.p_poses = []
        self.g_pos = None
        b, c, h, w = l.size()
        # 4,c,h,w -> 1,c,2h,2w
        concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)

        pools = []
        for pool_ratio in self.pool_ratios:
            # b,c,h,w
            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
            pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
            pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
            if self.g_pos is None:
                pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
                pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
                self.p_poses.append(pos_emb)
        pools = torch.cat(pools, 0)
        if self.g_pos is None:
            self.p_poses = torch.cat(self.p_poses, dim=0)
            pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
            self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')

        device = pools.device
        self.p_poses = self.p_poses.to(device)
        self.g_pos = self.g_pos.to(device)

        # attention between glb (q) & multi-scale concatenated locs (k, v)
        g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')

        g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
        g_hw_b_c = self.norm1(g_hw_b_c)
        g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
        g_hw_b_c = self.norm2(g_hw_b_c)

        # attention between original locs (q) & refreshed glb (k, v)
        l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
        _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
        _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
        outputs_re = []
        for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
            outputs_re.append(self.attention[i + 1](_l, _g, _g)[0])  # (h w) 1 c
        outputs_re = torch.cat(outputs_re, 1)  # (h w) 4 c

        l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
        l_hw_b_c = self.norm1(l_hw_b_c)
        l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
        l_hw_b_c = self.norm2(l_hw_b_c)

        l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # hw, b(5), c
        return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)

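# Note (added for clarity, not in the original source): MCLM fuses the four local
# crops with the global view. Per its docstring, `l` has shape (4, C, H, W) and
# `g` has shape (1, C, H, W); the concatenation before the final rearrange implies
# the returned tensor is (5, C, H, W), with the four locals first and the global last.
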
class MCRM(nn.Module):
    def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
        super(MCRM, self).__init__()
        self.attention = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
        ])
        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()
        self.activation = get_activation_fn('gelu')
        self.sal_conv = nn.Conv2d(d_model, 1, 1)
        self.pool_ratios = pool_ratios

    def forward(self, x):
        device = x.device
        b, c, h, w = x.size()
        loc, glb = x.split([4, 1], dim=0)  # 4,c,h,w; 1,c,h,w

        patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

        token_attention_map = self.sigmoid(self.sal_conv(glb))
        token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
        loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

        pools = []
        for pool_ratio in self.pool_ratios:
            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
            pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
            pools.append(rearrange(pool, 'nl c h w -> nl c (h w)'))  # nl(4),c,hw

        pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
        loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')

        outputs = []
        for i, q in enumerate(loc_.unbind(dim=0)):  # traverse all local patches
            v = pools[i]
            k = v
            outputs.append(self.attention[i](q, k, v)[0])

        outputs = torch.cat(outputs, 1)
        src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
        src = self.norm1(src)
        src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
        src = self.norm2(src)
        src = src.permute(1, 2, 0).reshape(4, c, h, w)  # refreshed loc
        glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # refreshed glb

        return torch.cat((src, glb), 0), token_attention_map

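# Note (added for clarity, not in the original source): MCRM expects the stacked
# (4 local + 1 global) feature map produced upstream, i.e. a (5, C, H, W) tensor,
# and returns the refreshed (5, C, H, W) stack together with the token attention
# map computed from the global view.
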
class BEN_Base(nn.Module):
    def __init__(self):
        super().__init__()

        self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
        emb_dim = 128
        self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))

        self.output5 = make_cbr(1024, emb_dim)
        self.output4 = make_cbr(512, emb_dim)
        self.output3 = make_cbr(256, emb_dim)
        self.output2 = make_cbr(128, emb_dim)
        self.output1 = make_cbr(128, emb_dim)

        self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
        self.conv1 = make_cbr(emb_dim, emb_dim)
        self.conv2 = make_cbr(emb_dim, emb_dim)
        self.conv3 = make_cbr(emb_dim, emb_dim)
        self.conv4 = make_cbr(emb_dim, emb_dim)
        self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])

        self.insmask_head = nn.Sequential(
            nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
            nn.InstanceNorm2d(384),
            nn.GELU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.InstanceNorm2d(384),
            nn.GELU(),
            nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
        )

        self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
        self.upsample1 = make_cbg(emb_dim, emb_dim)
        self.upsample2 = make_cbg(emb_dim, emb_dim)
        self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))

        for m in self.modules():
            if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
                m.inplace = True

    @torch.inference_mode()
    @torch.autocast(device_type="cuda", dtype=torch.float16)
    def forward(self, x):
        real_batch = x.size(0)

        shallow_batch = self.shallow(x)
        glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')

        final_input = None
        for i in range(real_batch):
            start = i * 4
            end = (i + 1) * 4
            loc_batch = image2patches(x[i, :, :, :].unsqueeze(dim=0))
            input_ = torch.cat((loc_batch, glb_batch[i, :, :, :].unsqueeze(dim=0)), dim=0)

            if final_input is None:
                final_input = input_
            else:
                final_input = torch.cat((final_input, input_), dim=0)

        features = self.backbone(final_input)
        outputs = []

        for i in range(real_batch):
            start = i * 5
            end = (i + 1) * 5

            f4 = features[4][start:end, :, :, :]  # shape: [5, C, H, W]
            f3 = features[3][start:end, :, :, :]
            f2 = features[2][start:end, :, :, :]
            f1 = features[1][start:end, :, :, :]
            f0 = features[0][start:end, :, :, :]
            e5 = self.output5(f4)
            e4 = self.output4(f3)
            e3 = self.output3(f2)
            e2 = self.output2(f1)
            e1 = self.output1(f0)
            loc_e5, glb_e5 = e5.split([4, 1], dim=0)
            e5 = self.multifieldcrossatt(loc_e5, glb_e5)  # (4,128,16,16)

            e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
            e4 = self.conv4(e4)
            e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
            e3 = self.conv3(e3)
            e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
            e2 = self.conv2(e2)
            e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
            e1 = self.conv1(e1)

            loc_e1, glb_e1 = e1.split([4, 1], dim=0)

            output1_cat = patches2image(loc_e1)  # (1,128,256,256)

            # add glb feat in
            output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
            # merge
            final_output = self.insmask_head(output1_cat)  # (1,128,256,256)
            # shallow feature merge
            shallow = shallow_batch[i, :, :, :].unsqueeze(dim=0)
            final_output = final_output + resize_as(shallow, final_output)
            final_output = self.upsample1(rescale_to(final_output))
            final_output = rescale_to(final_output + resize_as(shallow, final_output))
            final_output = self.upsample2(final_output)
            final_output = self.output(final_output)
            mask = final_output.sigmoid()
            outputs.append(mask)

        return torch.cat(outputs, dim=0)

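    # Shape note (added for clarity, not in the original source): with the
    # (B, 3, 1024, 1024) inputs produced by rgb_loader_refiner + img_transform,
    # the inline shape comments above imply that forward() returns a
    # (B, 1, 1024, 1024) sigmoid mask with values in [0, 1].
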
    def loadcheckpoints(self, model_path):
        model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        self.load_state_dict(model_dict['model_state_dict'], strict=True)
        del model_dict  # free the checkpoint dict once the weights are loaded

    def inference(self, image, refine_foreground=False):
        # set_random_seed(9)
        # image = ImageOps.exif_transpose(image)
        if isinstance(image, Image.Image):
            image, h, w, original_image = rgb_loader_refiner(image)
            if torch.cuda.is_available():
                img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
            else:
                img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)

            with torch.no_grad():
                res = self.forward(img_tensor)

            # Show Results
            if refine_foreground:
                pred_pil = transforms.ToPILImage()(res.squeeze())
                image_masked = refine_foreground_process(original_image, pred_pil)

                image_masked.putalpha(pred_pil.resize(original_image.size))
                return image_masked

            else:
                alpha = postprocess_image(res, im_size=[w, h])
                pred_pil = transforms.ToPILImage()(alpha)
                mask = pred_pil.resize(original_image.size)
                original_image.putalpha(mask)
                # mask = Image.fromarray(alpha)

                # composite the foreground onto a white background
                white_background = Image.new('RGB', original_image.size, (255, 255, 255))
                white_background.paste(original_image, mask=original_image.split()[3])
                original_image = white_background

                return original_image

        else:
            foregrounds = []
            for batch in image:
                image, h, w, original_image = rgb_loader_refiner(batch)
                if torch.cuda.is_available():
                    img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
                else:
                    img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)

                with torch.no_grad():
                    res = self.forward(img_tensor)

                if refine_foreground:
                    pred_pil = transforms.ToPILImage()(res.squeeze())
                    image_masked = refine_foreground_process(original_image, pred_pil)

                    image_masked.putalpha(pred_pil.resize(original_image.size))

                    foregrounds.append(image_masked)
                else:
                    alpha = postprocess_image(res, im_size=[w, h])
                    pred_pil = transforms.ToPILImage()(alpha)
                    mask = pred_pil.resize(original_image.size)
                    original_image.putalpha(mask)
                    # mask = Image.fromarray(alpha)
                    foregrounds.append(original_image)

            return foregrounds

                def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1,
         
     | 
| 1048 | 
         
            +
                                  print_frames_processed=True, webm=False, rgb_value=(0, 255, 0)):
         
     | 
| 1049 | 
         
            +
             
     | 
| 1050 | 
         
            +
                    """
         
     | 
| 1051 | 
         
            +
                    Segments the given video to extract the foreground (with alpha) from each frame
         
     | 
| 1052 | 
         
            +
                    and saves the result as either a WebM video (with alpha channel) or MP4 (with a
         
     | 
| 1053 | 
         
            +
                    color background).
         
     | 
| 1054 | 
         
            +
             
     | 
| 1055 | 
         
            +
                    Args:
         
     | 
| 1056 | 
         
            +
                        video_path (str):
         
     | 
| 1057 | 
         
            +
                            Path to the input video file.
         
     | 
| 1058 | 
         
            +
             
     | 
| 1059 | 
         
            +
                        output_path (str, optional):
         
     | 
| 1060 | 
         
            +
                            Directory (or full path) where the output video and/or files will be saved.
         
     | 
| 1061 | 
         
            +
                            Defaults to "./".
         
     | 
| 1062 | 
         
            +
             
     | 
| 1063 | 
         
            +
                        fps (int, optional):
         
     | 
| 1064 | 
         
            +
                            The frames per second (FPS) to use for the output video. If 0 (default), the
         
     | 
| 1065 | 
         
            +
                            original FPS of the input video is used. Otherwise, overrides it.
         
     | 
| 1066 | 
         
            +
             
     | 
| 1067 | 
         
            +
                        refine_foreground (bool, optional):
         
     | 
| 1068 | 
         
            +
                            Whether to run an additional “refine foreground” process on each frame.
         
     | 
| 1069 | 
         
            +
                            Defaults to False.
         
     | 
| 1070 | 
         
            +
             
     | 
| 1071 | 
         
            +
                        batch (int, optional):
         
     | 
| 1072 | 
         
            +
                            Number of frames to process at once (inference batch size). Large batch sizes
         
     | 
| 1073 | 
         
            +
                            may require more GPU memory. Defaults to 1.
         
     | 
| 1074 | 
         
            +
             
     | 
| 1075 | 
         
            +
                        print_frames_processed (bool, optional):
         
     | 
| 1076 | 
         
            +
                            If True (default), prints progress (how many frames have been processed) to
         
     | 
| 1077 | 
         
            +
                            the console.
         
     | 
| 1078 | 
         
            +
             
     | 
| 1079 | 
         
            +
                        webm (bool, optional):
         
     | 
| 1080 | 
         
            +
                            If True (default), exports a WebM video with alpha channel (VP9 / yuva420p).
         
     | 
| 1081 | 
         
            +
                            If False, exports an MP4 video composited over a solid color background.
         
     | 
| 1082 | 
         
            +
             
     | 
| 1083 | 
         
            +
                        rgb_value (tuple, optional):
         
     | 
| 1084 | 
         
            +
                            The RGB background color (e.g., green screen) used to composite frames when
         
     | 
| 1085 | 
         
            +
                            saving to MP4. Defaults to (0, 255, 0).
         
     | 
| 1086 | 
         
            +
             
     | 
| 1087 | 
         
            +
                    Returns:
         
     | 
| 1088 | 
         
            +
                        None. Writes the output video(s) to disk in the specified format.
         
     | 
| 1089 | 
         
            +
                    """
         
     | 
| 1090 | 
         
            +
             
     | 
| 1091 | 
         
            +
                    cap = cv2.VideoCapture(video_path)
         
     | 
| 1092 | 
         
            +
                    if not cap.isOpened():
         
     | 
| 1093 | 
         
            +
                        raise IOError(f"Cannot open video: {video_path}")
         
     | 
| 1094 | 
         
            +
             
     | 
| 1095 | 
         
            +
                    original_fps = cap.get(cv2.CAP_PROP_FPS)
         
     | 
| 1096 | 
         
            +
                    original_fps = 30 if original_fps == 0 else original_fps
         
     | 
| 1097 | 
         
            +
                    fps = original_fps if fps == 0 else fps
         
     | 
| 1098 | 
         
            +
             
     | 
| 1099 | 
         
            +
                    ret, first_frame = cap.read()
         
     | 
| 1100 | 
         
            +
                    if not ret:
         
     | 
| 1101 | 
         
            +
                        raise ValueError("No frames found in the video.")
         
     | 
| 1102 | 
         
            +
                    height, width = first_frame.shape[:2]
         
     | 
| 1103 | 
         
            +
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
         
     | 
| 1104 | 
         
            +
             
     | 
| 1105 | 
         
            +
                    foregrounds = []
         
     | 
| 1106 | 
         
            +
                    frame_idx = 0
         
     | 
| 1107 | 
         
            +
        processed_count = 0
        batch_frames = []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        while True:
            ret, frame = cap.read()
            if not ret:
                if batch_frames:
                    batch_results = self.inference(batch_frames, refine_foreground)
                    if isinstance(batch_results, Image.Image):
                        foregrounds.append(batch_results)
                    else:
                        foregrounds.extend(batch_results)
                    if print_frames_processed:
                        print(f"Processed frames {frame_idx - len(batch_frames) + 1} to {frame_idx} of {total_frames}")
                break

            # Process every frame instead of using intervals
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_frame = Image.fromarray(frame_rgb)
            batch_frames.append(pil_frame)

            if len(batch_frames) == batch:
                batch_results = self.inference(batch_frames, refine_foreground)
                if isinstance(batch_results, Image.Image):
                    foregrounds.append(batch_results)
                else:
                    foregrounds.extend(batch_results)
                if print_frames_processed:
                    print(f"Processed frames {frame_idx - batch + 1} to {frame_idx} of {total_frames}")
                batch_frames = []
                processed_count += batch

            frame_idx += 1

        # Done reading frames; release the capture before writing either output format.
        cap.release()

        if webm:
            alpha_webm_path = os.path.join(output_path, "foreground.webm")
            pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=original_fps)

        else:
            fg_output = os.path.join(output_path, 'foreground.mp4')

            pil_images_to_mp4(foregrounds, fg_output, fps=original_fps, rgb_value=rgb_value)
            cv2.destroyAllWindows()

            try:
                fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
                add_audio_to_video(fg_output, video_path, fg_audio_output)
            except Exception as e:
                print("Could not add audio to the foreground video")
                print(e)


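# Illustrative sketch (not part of the pipeline above): the batching pattern used
# in the video loop, shown in isolation. `run_batch` is a placeholder for a real
# per-batch model call such as self.inference(...), and "input.mp4" is an assumed
# path -- both exist only for this example.
def _example_batched_frame_reading(video_path="input.mp4", batch=4):
    frames_out = []
    batch_frames = []
    cap = cv2.VideoCapture(video_path)

    def run_batch(frames):
        # Placeholder for the model call; here it simply passes frames through.
        return frames

    while True:
        ret, frame = cap.read()
        if not ret:
            # Flush the final, possibly smaller batch.
            if batch_frames:
                frames_out.extend(run_batch(batch_frames))
            break
        batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        if len(batch_frames) == batch:
            frames_out.extend(run_batch(batch_frames))
            batch_frames = []

    cap.release()
    return frames_out

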
def rgb_loader_refiner(original_image):
    h, w = original_image.size

    image = original_image
    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize the image
    image = image.resize((1024, 1024), resample=Image.LANCZOS)

    return image.convert('RGB'), h, w, original_image


# Define the image transformation
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float16),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img_transform32 = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


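# Illustrative sketch: how the two transforms above are applied. A resized RGB
# PIL image becomes a normalized CHW tensor; unsqueeze(0) adds the batch
# dimension a model forward pass expects. The grey dummy image is an assumption
# made for the example.
def _example_apply_transforms():
    dummy = Image.new('RGB', (1024, 1024), (128, 128, 128))
    x16 = img_transform(dummy).unsqueeze(0)    # torch.float16, shape (1, 3, 1024, 1024)
    x32 = img_transform32(dummy).unsqueeze(0)  # torch.float32, same shape
    return x16, x32

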
def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
    """
    Converts an array of PIL images to an MP4 video.

    Args:
        images: List of PIL images
        output_path: Path to save the MP4 file
        fps: Frames per second (default: 24)
        rgb_value: Background RGB color tuple (default: green (0, 255, 0))
    """
    if not images:
        raise ValueError("No images provided to convert to MP4.")

    width, height = images[0].size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for image in images:
        # If image has alpha channel, composite onto the specified background color
        if image.mode == 'RGBA':
            # Create background image with specified RGB color
            background = Image.new('RGB', image.size, rgb_value)
            background = background.convert('RGBA')
            # Composite the image onto the background
            image = Image.alpha_composite(background, image)
            image = image.convert('RGB')
        else:
            # Ensure RGB format for non-alpha images
            image = image.convert('RGB')

        # Convert to OpenCV format and write
        open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        video_writer.write(open_cv_image)

    video_writer.release()


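# Illustrative sketch: writing a handful of synthetic RGBA frames with
# pil_images_to_mp4. The output path and the magenta key colour are assumptions
# made for the example; any writable path and RGB tuple work the same way.
def _example_pil_images_to_mp4(output_path="example_foreground.mp4"):
    frames = []
    for i in range(24):
        frame = Image.new('RGBA', (320, 240), (255, 255, 255, 0))  # fully transparent
        # Paste an opaque square that moves across the frame.
        square = Image.new('RGBA', (60, 60), (255, 0, 0, 255))
        frame.paste(square, (i * 10, 90))
        frames.append(frame)
    # Transparent areas are composited onto the chosen background colour.
    pil_images_to_mp4(frames, output_path, fps=12, rgb_value=(255, 0, 255))

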
def pil_images_to_webm_alpha(images, output_path, fps=30):
    """
    Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.

    NOTE: Not all players will display alpha in WebM.
          Browsers like Chrome/Firefox typically do support VP9 alpha.
    """
    if not images:
        raise ValueError("No images provided for WebM with alpha.")

    # Ensure output directory exists (fall back to the current directory when
    # output_path has no directory component).
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save frames as PNG (with alpha)
        for idx, img in enumerate(images):
            if img.mode != "RGBA":
                img = img.convert("RGBA")
            out_path = os.path.join(tmpdir, f"{idx:06d}.png")
            img.save(out_path, "PNG")

        # Construct ffmpeg command
        # -c:v libvpx-vp9   => VP9 encoder
        # -pix_fmt yuva420p => alpha-enabled pixel format
        # -auto-alt-ref 0   => helps preserve alpha frames (libvpx quirk)
        ffmpeg_cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", os.path.join(tmpdir, "%06d.png"),
            "-c:v", "libvpx-vp9",
            "-pix_fmt", "yuva420p",
            "-auto-alt-ref", "0",
            output_path
        ]

        subprocess.run(ffmpeg_cmd, check=True)

    print(f"WebM with alpha saved to {output_path}")


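# Illustrative sketch: exporting RGBA frames as a transparent WebM. This assumes
# an ffmpeg binary with libvpx-vp9 support is available on PATH; the output file
# name is a placeholder.
def _example_pil_images_to_webm_alpha(output_path="example_foreground.webm"):
    frames = [
        Image.new('RGBA', (320, 240), (0, 255, 0, 128))  # half-transparent green
        for _ in range(30)
    ]
    pil_images_to_webm_alpha(frames, output_path, fps=30)

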
def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
    """
    Check if the original video has an audio stream. If yes, add it. If not, skip.
    """
    # 1) Probe original video for audio streams
    probe_command = [
        'ffprobe', '-v', 'error',
        '-select_streams', 'a:0',
        '-show_entries', 'stream=index',
        '-of', 'csv=p=0',
        original_video_path
    ]
    result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    # result.stdout is empty if no audio stream found
    if not result.stdout.strip():
        print("No audio track found in original video, skipping audio addition.")
        return

    print("Audio track detected; proceeding to mux audio.")
    # 2) If audio found, run ffmpeg to add it
    command = [
        'ffmpeg', '-y',
        '-i', video_without_audio_path,
        '-i', original_video_path,
        '-c', 'copy',
        '-map', '0:v:0',
        '-map', '1:a:0',  # we know there's an audio track now
        output_path
    ]
    subprocess.run(command, check=True)
    print(f"Audio added successfully => {output_path}")


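# Illustrative sketch: muxing the audio track of a source clip into a freshly
# rendered (silent) video. Requires ffmpeg/ffprobe on PATH; the three paths are
# placeholders for this example.
def _example_add_audio_to_video():
    add_audio_to_video(
        video_without_audio_path="example_foreground.mp4",
        original_video_path="input.mp4",
        output_path="example_foreground_with_audio.mp4",
    )

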
### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
def refine_foreground_process(image, mask, r=90):
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image = np.array(image) / 255.0
    mask = np.array(mask) / 255.0
    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
    image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
    return image_masked


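# Illustrative sketch: refining a foreground with a single-channel ('L') alpha
# mask. The synthetic image and mask are assumptions; in practice the mask is
# the matte predicted by the segmentation model.
def _example_refine_foreground_process():
    image = Image.new('RGB', (256, 256), (200, 120, 40))
    mask = Image.new('L', (256, 256), 0)
    mask.paste(255, (64, 64, 192, 192))  # opaque square in the centre
    refined = refine_foreground_process(image, mask, r=90)
    return refined  # PIL RGB image with colour bleeding at the matte edges reduced

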
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]


def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]

    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    F = blurred_F + alpha * \
        (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F = np.clip(F, 0, 1)
    return F, blurred_B


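# The estimator above approximately inverts the compositing equation
# I = alpha * F + (1 - alpha) * B: box-blurred local averages give a first guess
# for F and B, and the residual (I - alpha*F_hat - (1-alpha)*B_hat) is fed back,
# weighted by alpha. Minimal sanity check (a direct consequence of the formula):
# with a fully opaque matte the recovered foreground equals the input image.
def _example_blur_fusion_sanity_check():
    image = np.full((64, 64, 3), 0.6)
    alpha = np.ones((64, 64, 1))
    F_est, _ = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r=9)
    assert np.allclose(F_est, image, atol=1e-3)
    return F_est

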
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array


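# Illustrative sketch: postprocess_image expects a batched single-channel
# prediction of shape (1, 1, H, W) -- a random tensor stands in for a model
# output here -- and returns a uint8 mask resized to the requested
# (height, width). `F` above refers to torch.nn.functional, which the module is
# expected to import near the top of the file.
def _example_postprocess_image():
    fake_logits = torch.rand(1, 1, 64, 64)        # stand-in for a model prediction
    mask = postprocess_image(fake_logits, im_size=[240, 320])
    # mask is a (240, 320) numpy array with values in [0, 255]
    return mask

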
def rgb_loader_refiner(original_image):
    # NOTE: this redefinition overrides the rgb_loader_refiner defined earlier in
    # this file, so only this version (with EXIF handling) takes effect at runtime.
    h, w = original_image.size  # PIL .size is (width, height)

    # Apply EXIF orientation so rotated photos come out upright
    image = ImageOps.exif_transpose(original_image)

    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')
    if original_image.mode != 'RGB':
        original_image = original_image.convert('RGB')

    # Resize the image
    image = image.resize((1024, 1024), resample=Image.LANCZOS)

    return image, h, w, original_image
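

# Illustrative sketch: the single-image preprocessing path implied by the helpers
# above -- EXIF-corrected resize, normalisation, batch dimension -- ready to feed
# to the segmentation model defined earlier in this file (the model call itself
# is omitted). The input path is a placeholder.
def _example_preprocess_single_image(path="example_inputs/cat.png"):
    original = Image.open(path)
    image, h, w, original = rgb_loader_refiner(original)
    tensor = img_transform32(image).unsqueeze(0)  # (1, 3, 1024, 1024), float32
    # h and w follow PIL's .size ordering (width, height) and can be used to map
    # predictions back to the source resolution.
    return tensor, (h, w), original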