Spaces · wondervictor (Space status: Runtime error)

Commit 2422035 · wondervictor committed: update README
1 Parent(s): f6bd4fa

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
- .gitignore +163 -0
 - app.py +29 -4
 - app_canny.py +100 -0
 - app_depth.py +92 -0
 - autoregressive/models/README.md +6 -0
 - autoregressive/models/dinov2_adapter.py +36 -0
 - autoregressive/models/generate.py +204 -0
 - autoregressive/models/gpt_t2i.py +561 -0
 - autoregressive/sample/sample_c2i.py +151 -0
 - autoregressive/sample/sample_c2i_ddp.py +188 -0
 - autoregressive/sample/sample_t2i.py +215 -0
 - autoregressive/sample/sample_t2i_MR.py +237 -0
 - autoregressive/sample/sample_t2i_ddp.py +229 -0
 - checkpoints/vq_ds16_t2i.pt +3 -0
 - condition/README.md +23 -0
 - condition/canny.py +25 -0
 - condition/depth.py +47 -0
 - condition/example/t2i/multi_resolution/bird.jpg +0 -0
 - condition/example/t2i/multi_resolution/car.jpg +0 -0
 - condition/example/t2i/multigen/doll.jpg +0 -0
 - condition/example/t2i/multigen/girl.jpg +0 -0
 - condition/example/t2i/multigen/house.jpg +0 -0
 - condition/example/t2i/multigen/sofa.png +0 -0
 - condition/hed.py +117 -0
 - condition/lineart.py +98 -0
 - condition/midas/depth.py +223 -0
 - condition/midas/midas/__init__.py +0 -0
 - condition/midas/midas/base_model.py +16 -0
 - condition/midas/midas/blocks.py +341 -0
 - condition/midas/midas/dpt_depth.py +108 -0
 - condition/midas/midas/midas_net.py +76 -0
 - condition/midas/midas/midas_net_custom.py +128 -0
 - condition/midas/midas/transforms.py +234 -0
 - condition/midas/midas/vit.py +491 -0
 - condition/utils.py +38 -0
 - language/README.md +14 -0
 - language/extract_t5_feature.py +129 -0
 - language/t5.py +201 -0
 - model.py +242 -0
 - style.css +10 -0
 - tokenizer/consistencydecoder/README.md +14 -0
 - tokenizer/consistencydecoder/cd_demo.py +57 -0
 - tokenizer/consistencydecoder/reconstruction_cd_ddp.py +208 -0
 - tokenizer/tokenizer_image/cache/vgg.pth +3 -0
 - tokenizer/tokenizer_image/discriminator.py +255 -0
 - tokenizer/tokenizer_image/discriminator_patchgan.py +152 -0
 - tokenizer/tokenizer_image/discriminator_stylegan.py +101 -0
 - tokenizer/tokenizer_image/lpips.py +164 -0
 - tokenizer/tokenizer_image/reconstruction_vq_ddp.py +207 -0
 - tokenizer/tokenizer_image/vq_demo.py +84 -0
 
    	
.gitignore ADDED
@@ -0,0 +1,163 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
    	
app.py CHANGED
@@ -1,7 +1,32 @@
+from PIL import Image
 import gradio as gr
+from huggingface_hub import hf_hub_download
+from model import Model
+from app_canny import create_demo as create_demo_canny
+from app_depth import create_demo as create_demo_depth
+import os
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-
-
+hf_hub_download('wondervictor/ControlAR', filename='canny_MR.safetensors', cache_dir='./checkpoints/')
+hf_hub_download('wondervictor/ControlAR', filename='depth_MR.safetensors', cache_dir='./checkpoints/')
+
+
+DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR.  \n ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)."
+SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
+model = Model()
+device = "cuda"
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=SHOW_DUPLICATE_BUTTON,
+    )
+    with gr.Tabs():
+        with gr.TabItem("Depth"):
+            create_demo_depth(model.process_depth)
+        with gr.TabItem("Canny"):
+            create_demo_canny(model.process_canny)
+
+if __name__ == "__main__":
+    demo.queue().launch(share=False, server_name="0.0.0.0")
    	
app_canny.py ADDED
@@ -0,0 +1,100 @@
+import gradio as gr
+import random
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, 100000000)
+    return seed
+examples = [
+    [
+        "condition/example/t2i/multigen/doll.png",
+        "A stuffed animal wearing a mask and a leash, sitting on a blanket",
+        "(512, 512)"
+    ],
+    [
+        "condition/example/t2i/multigen/girl.png",
+        "An anime style girl with blue hair",
+        "(512, 512)"
+    ],
+    [
+        "condition/example/t2i/multi_resolution/bird.jpg",
+        "colorful bird",
+        "(921, 564)"
+    ],
+]
+def create_demo(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            with gr.Column():
+                image = gr.Image()
+                prompt = gr.Textbox(label="Prompt")
+                run_button = gr.Button("Run")
+                with gr.Accordion("Advanced options", open=False):
+                    canny_low_threshold = gr.Slider(
+                        label="Canny low threshold", minimum=0, maximum=1000, value=100, step=50
+                    )
+                    canny_high_threshold = gr.Slider(
+                        label="Canny high threshold", minimum=0, maximum=1000, value=200, step=50
+                    )
+                    cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1)
+                    relolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16)
+                    top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K')
+                    top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
+                    temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0)
+                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            with gr.Column():
+                result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, object_fit="scale-down")
+        gr.Examples(
+            examples=examples,
+            inputs=[
+                image,
+                prompt,
+                relolution,
+            ],
+            outputs=result,
+            fn=process,
+        )
+        inputs = [
+            image,
+            prompt,
+            cfg_scale,
+            temperature,
+            top_k,
+            top_p,
+            seed,
+            canny_low_threshold,
+            canny_high_threshold,
+        ]
+        prompt.submit(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+            api_name=False,
+        ).then(
+            fn=process,
+            inputs=inputs,
+            outputs=result,
+            api_name=False,
+        )
+        run_button.click(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+            api_name=False,
+        ).then(
+            fn=process,
+            inputs=inputs,
+            outputs=result,
+            api_name="canny",
+        )
+    return demo
+if __name__ == "__main__":
+    from model import Model
+    model = Model()
+    demo = create_demo(model.process_canny)
+    demo.queue().launch(
+    share=False,
+    server_name="0.0.0.0"
+    )
    	
app_depth.py ADDED
@@ -0,0 +1,92 @@
+import gradio as gr
+import random
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, 100000000)
+    return seed
+examples = [
+    [
+        "condition/example/t2i/multigen/sofa.png",
+        "The red sofa in the living room has several pillows on it",
+        "(512, 512)"
+    ],
+    [
+        "condition/example/t2i/multigen/house.png",
+        "A brick house with a chimney under a starry sky.",
+        "(512, 512)"
+    ],
+    [
+        "condition/example/t2i/multi_resolution/car.jpg",
+        "a sport car",
+        "(448, 768)"
+    ]
+]
+def create_demo(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            with gr.Column():
+                image = gr.Image()
+                prompt = gr.Textbox(label="Prompt")
+                run_button = gr.Button("Run")
+                with gr.Accordion("Advanced options", open=False):
+                    cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1)
+                    resolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16)
+                    top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K')
+                    top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
+                    temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0)
+                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            with gr.Column():
+                result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, object_fit="scale-down")
+        gr.Examples(
+            examples=examples,
+            inputs=[
+                image,
+                prompt,
+                resolution,
+            ],
+            outputs=result,
+            fn=process,
+        )
+        inputs = [
+            image,
+            prompt,
+            cfg_scale,
+            temperature,
+            top_k,
+            top_p,
+            seed,
+        ]
+        prompt.submit(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+            api_name=False,
+        ).then(
+            fn=process,
+            inputs=inputs,
+            outputs=result,
+            api_name=False,
+        )
+        run_button.click(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+            api_name=False,
+        ).then(
+            fn=process,
+            inputs=inputs,
+            outputs=result,
+            api_name="canny",
+        )
+    return demo
+if __name__ == "__main__":
+    from model import Model
+    model = Model()
+    demo = create_demo(model.process_depth)
+    demo.queue().launch(
+    share=False,
+    server_name="0.0.0.0"
+    )
    	
autoregressive/models/README.md ADDED
@@ -0,0 +1,6 @@
+Download the vit weight first
+
+ViT-small: https://huggingface.co/WinKawaks/vit-small-patch16-224 \
+Dinov2-small: https://huggingface.co/facebook/dinov2-small
+
+Put them here
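The adapter added below loads DINOv2 from the local path autoregressive/models/dinov2-small, so "put them here" means placing the downloaded weights in this directory. A minimal sketch of one way to fetch them with huggingface_hub (the local_dir values are an assumption based on that path convention, not part of this commit):

from huggingface_hub import snapshot_download

# DINOv2-small, which Dinov2_Adapter expects at autoregressive/models/dinov2-small
snapshot_download("facebook/dinov2-small", local_dir="autoregressive/models/dinov2-small")
# ViT-small, the alternative backbone linked in the README
snapshot_download("WinKawaks/vit-small-patch16-224", local_dir="autoregressive/models/vit-small-patch16-224")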
    	
autoregressive/models/dinov2_adapter.py ADDED
@@ -0,0 +1,36 @@
+from transformers import AutoImageProcessor, AutoModel
+from PIL import Image
+import requests
+import torch
+import torch.nn as nn
+
+
+class Dinov2_Adapter(nn.Module):
+    def __init__(self, input_dim=1, output_dim=768, attention=False, pool=False, nheads=8, dropout=0.1, adapter_size='small', condition_type='canny'):
+        super(Dinov2_Adapter, self).__init__()
+        print(f"Choose adapter size: {adapter_size}")
+        print(f"condition type: {condition_type}")
+        self.model = AutoModel.from_pretrained(f'autoregressive/models/dinov2-{adapter_size}')
+        self.condition_type = condition_type
+
+    def to_patch14(self, input):
+        H, W = input.shape[2:]
+        new_H = (H // 16) * 14
+        new_W = (W // 16) * 14
+        if self.condition_type in ['canny', 'seg']:
+            output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='nearest')#, align_corners=True)  canny, seg
+        else:
+            output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='bicubic', align_corners=True) # depth, lineart, hed
+        return output
+
+    def forward(self, x):
+        x = self.to_patch14(x)
+        x = self.model(x)
+        return x.last_hidden_state[:, 1:]
+
+
+if __name__ == '__main__':
+    model = Dinov2_Adapter().cuda()
+    inputs = torch.randn(4,3,512,512).cuda()
+    outputs = model(inputs)
+    print(outputs.shape)
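A quick arithmetic check of the 16-to-14 rescaling in to_patch14 (this snippet is not part of the commit, and the stated motivation is an inference from the code): DINOv2 uses 14x14 patches while the rest of the pipeline works on a 16-pixel grid, so resizing each side to (side // 16) * 14 keeps the number of patch tokens per side unchanged.

# e.g. a 512x512 condition image becomes 448x448, and 448 / 14 == 512 / 16 == 32
H = W = 512
new_H, new_W = (H // 16) * 14, (W // 16) * 14
assert (new_H, new_W) == (448, 448)
assert new_H // 14 == H // 16 == 32  # same 32x32 token grid either way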
    	
autoregressive/models/generate.py ADDED
@@ -0,0 +1,204 @@
+# Modified from:
+#   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+#   DiT:      https://github.com/facebookresearch/DiT/blob/main/models.py
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torch._dynamo.config
+import torch._inductor.config
+import copy
+import time
+# torch._inductor.config.coordinate_descent_tuning = True
+# torch._inductor.config.triton.unique_kernel_names = True
+# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+
+
+### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
+def top_k_top_p_filtering(
+    logits,
+    top_k: int = 0,
+    top_p: float = 1.0,
+    filter_value: float = -float("Inf"),
+    min_tokens_to_keep: int = 1,
+):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        # import pdb;pdb.set_trace()
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        logits[indices_to_remove] = filter_value
+    return logits
+
+
+def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sample_logits=True):
+    logits = logits[:, -1, :] / max(temperature, 1e-5)
+    if top_k > 0 or top_p < 1.0:
+        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+    probs = F.softmax(logits, dim=-1)
+    # values, indices = torch.max(probs, dim=1, keepdim=True)
+    # mask = (probs == values).float()
+    # probs = probs * (1 - mask)
+    # values, indices = torch.max(probs, dim=1, keepdim=True)
(remaining lines of generate.py not shown in this view)
                # mask = (probs == values).float()
         
     | 
| 69 | 
         
            +
                # probs = probs * (1 - mask)
         
     | 
| 70 | 
         
            +
                if sample_logits:
         
     | 
| 71 | 
         
            +
                    idx = torch.multinomial(probs, num_samples=1)
         
     | 
| 72 | 
         
            +
                else:
         
     | 
| 73 | 
         
            +
                    _, idx = torch.topk(probs, k=1, dim=-1)
         
     | 
| 74 | 
         
            +
                return idx, probs
         
     | 
| 75 | 
         
            +
             
     | 

def logits_to_probs(logits, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, **kwargs):
    # top_k defaults to 0 (disabled) so the `top_k > 0` check below is well-defined.
    logits = logits / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition: torch.Tensor, **sampling_kwargs):
    if cfg_scale > 1.0:
        logits, _ = model(None, cond_idx, input_pos, condition=condition)
        logits_combined = logits
        cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
        logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
    else:
        logits, _ = model(None, cond_idx, input_pos, condition=condition)

    return sample(logits, **sampling_kwargs)[0]

def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, condition: torch.Tensor, **sampling_kwargs):
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        x_combined = torch.cat([x, x])
        logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos, condition=condition)
        logits_combined = logits
        cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
        if cfg_flag:
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        else:
            logits = cond_logits
    else:
        logits, _ = model(x, cond_idx=None, input_pos=input_pos, condition=None)
    return sample(logits, **sampling_kwargs)

def decode_n_tokens(
    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
    cfg_scale: float, cfg_interval: int, condition: torch.Tensor,
    **sampling_kwargs):
    new_tokens, new_probs = [], []
    cfg_flag = True
    for i in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
            if cfg_interval > -1 and i > cfg_interval:
                cfg_flag = False
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, cfg_flag, condition=condition, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(-1, 1)

    return new_tokens, new_probs

@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
    if condition is not None:
        condition = model.adapter(condition)
        condition = model.adapter_mlp(condition)
    if model.model_type == 'c2i':
        if cfg_scale > 1.0:
            cond_null = torch.ones_like(cond) * model.num_classes
            cond_combined = torch.cat([cond, cond_null])
            if condition is not None:
                condition_null = torch.zeros_like(condition)
                condition_combined = torch.cat((condition, condition_null), dim=0)
            else:
                condition_combined = None
        else:
            cond_combined = cond
            if condition is not None:
                condition_combined = condition
            else:
                condition_combined = None
        T = 1 + condition_token_nums
    elif model.model_type == 't2i':
        if cfg_scale > 1.0:
            cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
            cond_combined = torch.cat([cond, cond_null])

            if condition is not None:
                condition_null = torch.zeros_like(condition)
                condition_combined = torch.cat((condition, condition_null), dim=0)
            else:
                condition_combined = None
        else:
            cond_combined = cond
            if condition is not None:
                condition_combined = condition
            else:
                condition_combined = None
        T = cond.shape[1]
    else:
        raise Exception("please check model type")

    T_new = T + max_new_tokens
    max_seq_length = T_new
    max_batch_size = cond.shape[0]

    device = cond.device
    with torch.device(device):
        max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
        model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)

    if emb_masks is not None:
        assert emb_masks.shape[0] == max_batch_size
        assert emb_masks.shape[-1] == T
        if cfg_scale > 1.0:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
        else:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)

        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix

    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, **sampling_kwargs)
    seq[:, T:T+1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
    seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
    return seq[:, T:]
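For orientation, a minimal usage sketch of the `generate` entry point above. The model handle, class ids, and token count are illustrative assumptions (a 'c2i' Transformer from gpt_t2i.py already on the GPU), not values fixed by this file; the extra keyword arguments are forwarded through `decode_n_tokens` and `decode_one_token` into `sample`, so `temperature`, `top_k`, `top_p`, and `sample_logits` can be set per call.

# Illustrative sketch, not part of this file: class-conditional sampling.
# Assumes `model` is a 'c2i' Transformer (gpt_t2i.py) on CUDA; 256 new tokens
# correspond to a 16x16 grid of VQ codes for the image tokenizer's decoder.
import torch

class_labels = torch.tensor([207, 360], device='cuda')   # (B,) ImageNet class ids
codes = generate(
    model, class_labels, max_new_tokens=256,
    cfg_scale=4.0, cfg_interval=-1,
    temperature=1.0, top_k=2000, top_p=1.0, sample_logits=True,
)
# codes: (B, 256) int tensor of token indices, ready for the VQ decoder.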
    	
        autoregressive/models/gpt_t2i.py
    ADDED
    
@@ -0,0 +1,561 @@
| 1 | 
         
            +
            # Modified from:
         
     | 
| 2 | 
         
            +
            #   VQGAN:    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
         
     | 
| 3 | 
         
            +
            #   DiT:      https://github.com/facebookresearch/DiT/blob/main/models.py  
         
     | 
| 4 | 
         
            +
            #   nanoGPT:  https://github.com/karpathy/nanoGPT/blob/master/model.py
         
     | 
| 5 | 
         
            +
            #   llama:    https://github.com/facebookresearch/llama/blob/main/llama/model.py
         
     | 
| 6 | 
         
            +
            #   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
         
     | 
| 7 | 
         
            +
            #   PixArt:   https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
         
     | 
| 8 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 9 | 
         
            +
            from typing import Optional, List
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            import torch
         
     | 
| 13 | 
         
            +
            import torch.nn as nn
         
     | 
| 14 | 
         
            +
            from torch.nn import functional as F
         
     | 
| 15 | 
         
            +
            from utils.drop_path import DropPath
         
     | 
| 16 | 
         
            +
            # from autoregressive.models.vit_adapter import ViT_Adapter
         
     | 
| 17 | 
         
            +
            from autoregressive.models.dinov2_adapter import Dinov2_Adapter
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            def get_causal_mask(seq_length):
         
     | 
| 21 | 
         
            +
                mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).type(torch.bool)
         
     | 
| 22 | 
         
            +
                mask = mask.masked_fill(mask, float('-inf'))  
         
     | 
| 23 | 
         
            +
                mask = mask.masked_fill(~mask, float(0.0))  
         
     | 
| 24 | 
         
            +
                return mask
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            def find_multiple(n: int, k: int):
         
     | 
| 27 | 
         
            +
                if n % k == 0:
         
     | 
| 28 | 
         
            +
                    return n
         
     | 
| 29 | 
         
            +
                return n + k - (n % k)
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            @dataclass
         
     | 
| 32 | 
         
            +
            class ModelArgs:
         
     | 
| 33 | 
         
            +
                dim: int = 4096
         
     | 
| 34 | 
         
            +
                n_layer: int = 32
         
     | 
| 35 | 
         
            +
                n_head: int = 32
         
     | 
| 36 | 
         
            +
                n_kv_head: Optional[int] = None
         
     | 
| 37 | 
         
            +
                multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
         
     | 
| 38 | 
         
            +
                ffn_dim_multiplier: Optional[float] = None
         
     | 
| 39 | 
         
            +
                rope_base: float = 10000
         
     | 
| 40 | 
         
            +
                norm_eps: float = 1e-5
         
     | 
| 41 | 
         
            +
                initializer_range: float = 0.02
         
     | 
| 42 | 
         
            +
                
         
     | 
| 43 | 
         
            +
                token_dropout_p: float = 0.1
         
     | 
| 44 | 
         
            +
                attn_dropout_p: float = 0.0
         
     | 
| 45 | 
         
            +
                resid_dropout_p: float = 0.1
         
     | 
| 46 | 
         
            +
                ffn_dropout_p: float = 0.1
         
     | 
| 47 | 
         
            +
                drop_path_rate: float = 0.0
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                num_classes: int = 1000
         
     | 
| 50 | 
         
            +
                caption_dim: int = 2048
         
     | 
| 51 | 
         
            +
                class_dropout_prob: float = 0.1
         
     | 
| 52 | 
         
            +
                model_type: str = 'c2i'
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                vocab_size: int = 16384
         
     | 
| 55 | 
         
            +
                cls_token_num: int = 1
         
     | 
| 56 | 
         
            +
                block_size: int = 256
         
     | 
| 57 | 
         
            +
                max_batch_size: int = 32
         
     | 
| 58 | 
         
            +
                max_seq_len: int = 2048
         
     | 
| 59 | 
         
            +
                adapter_size: str = 'small'
         
     | 
| 60 | 
         
            +
                condition_type: str = 'canny'
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
            #################################################################################
         
     | 
| 65 | 
         
            +
            #                      Embedding Layers for Class Labels                        #
         
     | 
| 66 | 
         
            +
            #################################################################################
         
     | 
| 67 | 
         
            +
            class LabelEmbedder(nn.Module):
         
     | 
| 68 | 
         
            +
                """
         
     | 
| 69 | 
         
            +
                Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
         
     | 
| 70 | 
         
            +
                """
         
     | 
| 71 | 
         
            +
                def __init__(self, num_classes, hidden_size, dropout_prob):
         
     | 
| 72 | 
         
            +
                    super().__init__()
         
     | 
| 73 | 
         
            +
                    use_cfg_embedding = dropout_prob > 0
         
     | 
| 74 | 
         
            +
                    self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
         
     | 
| 75 | 
         
            +
                    self.num_classes = num_classes
         
     | 
| 76 | 
         
            +
                    self.dropout_prob = dropout_prob
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
                def token_drop(self, labels, force_drop_ids=None):
         
     | 
| 79 | 
         
            +
                    """
         
     | 
| 80 | 
         
            +
                    Drops labels to enable classifier-free guidance.
         
     | 
| 81 | 
         
            +
                    """
         
     | 
| 82 | 
         
            +
                    if force_drop_ids is None:
         
     | 
| 83 | 
         
            +
                        drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
         
     | 
| 84 | 
         
            +
                    else:
         
     | 
| 85 | 
         
            +
                        drop_ids = force_drop_ids == 1
         
     | 
| 86 | 
         
            +
                    labels = torch.where(drop_ids, self.num_classes, labels)
         
     | 
| 87 | 
         
            +
                    return labels, drop_ids
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                def forward(self, labels, train, force_drop_ids=None):
         
     | 
| 90 | 
         
            +
                    use_dropout = self.dropout_prob > 0
         
     | 
| 91 | 
         
            +
                    if (train and use_dropout) or (force_drop_ids is not None):
         
     | 
| 92 | 
         
            +
                        labels,drop_ids = self.token_drop(labels, force_drop_ids)
         
     | 
| 93 | 
         
            +
                    embeddings = self.embedding_table(labels).unsqueeze(1)
         
     | 
| 94 | 
         
            +
                    if (train and use_dropout) or (force_drop_ids is not None):
         
     | 
| 95 | 
         
            +
                        return embeddings,drop_ids
         
     | 
| 96 | 
         
            +
                    else:
         
     | 
| 97 | 
         
            +
                        return embeddings
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
            class ConditionEmbedder(nn.Module):
         
     | 
| 101 | 
         
            +
                """
         
     | 
| 102 | 
         
            +
                Embeds Condition into vector representations. Also handles label dropout for classifier-free guidance.
         
     | 
| 103 | 
         
            +
                """
         
     | 
| 104 | 
         
            +
                def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384):
         
     | 
| 105 | 
         
            +
                    super().__init__()
         
     | 
| 106 | 
         
            +
                    self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size)
         
     | 
| 107 | 
         
            +
                    self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5)
         
     | 
| 108 | 
         
            +
                    self.uncond_prob = uncond_prob
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
                def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
         
     | 
| 111 | 
         
            +
                    """
         
     | 
| 112 | 
         
            +
                    Drops labels to enable classifier-free guidance.
         
     | 
| 113 | 
         
            +
                    """
         
     | 
| 114 | 
         
            +
                    if force_drop_ids is None:
         
     | 
| 115 | 
         
            +
                        if drop_ids is None:
         
     | 
| 116 | 
         
            +
                            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
         
     | 
| 117 | 
         
            +
                    else:
         
     | 
| 118 | 
         
            +
                        drop_ids = force_drop_ids == 1
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                    caption = torch.where(drop_ids[:, None, None], self.uncond_embedding[:caption.shape[1]], caption)
         
     | 
| 121 | 
         
            +
                    return caption
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
                def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
         
     | 
| 124 | 
         
            +
                    use_dropout = self.uncond_prob > 0
         
     | 
| 125 | 
         
            +
                    if (train and use_dropout) or (force_drop_ids is not None):
         
     | 
| 126 | 
         
            +
                        caption = self.token_drop(caption, force_drop_ids, drop_ids)
         
     | 
| 127 | 
         
            +
                    embeddings = self.cap_proj(caption)
         
     | 
| 128 | 
         
            +
                    return embeddings
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
            #################################################################################
         
     | 
| 131 | 
         
            +
            #                      Embedding Layers for Text Feature                        #
         
     | 
| 132 | 
         
            +
            #################################################################################
         
     | 
| 133 | 
         
            +
            class CaptionEmbedder(nn.Module):
         
     | 
| 134 | 
         
            +
                """
         
     | 
| 135 | 
         
            +
                Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance.
         
     | 
| 136 | 
         
            +
                """
         
     | 
| 137 | 
         
            +
                def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
         
     | 
| 138 | 
         
            +
                    super().__init__()
         
     | 
| 139 | 
         
            +
                    self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
         
     | 
| 140 | 
         
            +
                    self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
         
     | 
| 141 | 
         
            +
                    self.uncond_prob = uncond_prob
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
                def token_drop(self, caption, force_drop_ids=None):
         
     | 
| 144 | 
         
            +
                    """
         
     | 
| 145 | 
         
            +
                    Drops labels to enable classifier-free guidance.
         
     | 
| 146 | 
         
            +
                    """
         
     | 
| 147 | 
         
            +
                    if force_drop_ids is None:
         
     | 
| 148 | 
         
            +
                        drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
         
     | 
| 149 | 
         
            +
                    else:
         
     | 
| 150 | 
         
            +
                        drop_ids = force_drop_ids == 1
         
     | 
| 151 | 
         
            +
                    caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)
         
     | 
| 152 | 
         
            +
                    return caption, drop_ids
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
                def forward(self, caption, train, force_drop_ids=None):
         
     | 
| 155 | 
         
            +
                    use_dropout = self.uncond_prob > 0
         
     | 
| 156 | 
         
            +
                    if (train and use_dropout) or (force_drop_ids is not None):
         
     | 
| 157 | 
         
            +
                        caption, drop_ids = self.token_drop(caption, force_drop_ids)
         
     | 
| 158 | 
         
            +
                    embeddings = self.cap_proj(caption)
         
     | 
| 159 | 
         
            +
                    if (train and use_dropout) or (force_drop_ids is not None):
         
     | 
| 160 | 
         
            +
                        return embeddings,drop_ids
         
     | 
| 161 | 
         
            +
                    else:
         
     | 
| 162 | 
         
            +
                        return embeddings
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
             
     | 
| 165 | 
         
            +
            class MLP(nn.Module):
         
     | 
| 166 | 
         
            +
                def __init__(self, in_features, hidden_features, out_features):
         
     | 
| 167 | 
         
            +
                    super().__init__()
         
     | 
| 168 | 
         
            +
                    out_features = out_features or in_features
         
     | 
| 169 | 
         
            +
                    hidden_features = hidden_features or in_features
         
     | 
| 170 | 
         
            +
                    self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
         
     | 
| 171 | 
         
            +
                    self.act = nn.GELU(approximate='tanh')
         
     | 
| 172 | 
         
            +
                    self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
         
     | 
| 173 | 
         
            +
                    
         
     | 
| 174 | 
         
            +
                    nn.init.zeros_(self.fc1.weight)
         
     | 
| 175 | 
         
            +
                    nn.init.zeros_(self.fc2.weight)
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                def forward(self, x):
         
     | 
| 178 | 
         
            +
                    x = self.fc1(x)
         
     | 
| 179 | 
         
            +
                    x = self.act(x)
         
     | 
| 180 | 
         
            +
                    x = self.fc2(x)
         
     | 
| 181 | 
         
            +
                    return x
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
            #################################################################################
         
     | 
| 185 | 
         
            +
            #                                  GPT Model                                    #
         
     | 
| 186 | 
         
            +
            #################################################################################
         
     | 
| 187 | 
         
            +
            class RMSNorm(torch.nn.Module):
         
     | 
| 188 | 
         
            +
                def __init__(self, dim: int, eps: float = 1e-5):
         
     | 
| 189 | 
         
            +
                    super().__init__()
         
     | 
| 190 | 
         
            +
                    self.eps = eps
         
     | 
| 191 | 
         
            +
                    self.weight = nn.Parameter(torch.ones(dim))
         
     | 
| 192 | 
         
            +
             
     | 
| 193 | 
         
            +
                def _norm(self, x):
         
     | 
| 194 | 
         
            +
                    return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
         
     | 
| 195 | 
         
            +
             
     | 
| 196 | 
         
            +
                def forward(self, x):
         
     | 
| 197 | 
         
            +
                    output = self._norm(x.float()).type_as(x)
         
     | 
| 198 | 
         
            +
                    return output * self.weight
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
            +
            class FeedForward(nn.Module):
         
     | 
| 202 | 
         
            +
                def __init__(self, config: ModelArgs):
         
     | 
| 203 | 
         
            +
                    super().__init__()
         
     | 
| 204 | 
         
            +
                    hidden_dim = 4 * config.dim
         
     | 
| 205 | 
         
            +
                    hidden_dim = int(2 * hidden_dim / 3)
         
     | 
| 206 | 
         
            +
                    # custom dim factor multiplier
         
     | 
| 207 | 
         
            +
                    if config.ffn_dim_multiplier is not None:
         
     | 
| 208 | 
         
            +
                        hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
         
     | 
| 209 | 
         
            +
                    hidden_dim = find_multiple(hidden_dim, config.multiple_of)
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                    self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
         
     | 
| 212 | 
         
            +
                    self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
         
     | 
| 213 | 
         
            +
                    self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
         
     | 
| 214 | 
         
            +
                    self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
         
     | 
| 215 | 
         
            +
             
     | 
| 216 | 
         
            +
                def forward(self, x):
         
     | 
| 217 | 
         
            +
                    return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
            class KVCache(nn.Module):
         
     | 
| 221 | 
         
            +
                def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
         
     | 
| 222 | 
         
            +
                    super().__init__()
         
     | 
| 223 | 
         
            +
                    cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
         
     | 
| 224 | 
         
            +
                    self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
         
     | 
| 225 | 
         
            +
                    self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
         
     | 
| 226 | 
         
            +
             
     | 
| 227 | 
         
            +
                def update(self, input_pos, k_val, v_val):
         
     | 
| 228 | 
         
            +
                    # input_pos: [S], k_val: [B, H, S, D]
         
     | 
| 229 | 
         
            +
                    assert input_pos.shape[0] == k_val.shape[2]
         
     | 
| 230 | 
         
            +
                    k_out = self.k_cache
         
     | 
| 231 | 
         
            +
                    v_out = self.v_cache
         
     | 
| 232 | 
         
            +
                    k_out[:, :, input_pos] = k_val
         
     | 
| 233 | 
         
            +
                    v_out[:, :, input_pos] = v_val
         
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
                    return k_out, v_out
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
            class Attention(nn.Module):
         
     | 
| 239 | 
         
            +
                def __init__(self, config: ModelArgs):
         
     | 
| 240 | 
         
            +
                    super().__init__()
         
     | 
| 241 | 
         
            +
                    assert config.dim % config.n_head == 0
         
     | 
| 242 | 
         
            +
                    self.dim = config.dim
         
     | 
| 243 | 
         
            +
                    self.head_dim = config.dim // config.n_head
         
     | 
| 244 | 
         
            +
                    self.n_head = config.n_head
         
     | 
| 245 | 
         
            +
                    self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
         
     | 
| 246 | 
         
            +
                    total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
         
     | 
| 247 | 
         
            +
             
     | 
| 248 | 
         
            +
                    # key, query, value projections for all heads, but in a batch
         
     | 
| 249 | 
         
            +
                    self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
         
     | 
| 250 | 
         
            +
                    self.wo = nn.Linear(config.dim, config.dim, bias=False)
         
     | 
| 251 | 
         
            +
                    self.kv_cache = None
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
                    # regularization
         
     | 
| 254 | 
         
            +
        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ):
        bsz, seqlen, _ = x.shape
        kv_size = self.n_kv_head * self.head_dim
        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)

        xq = apply_rotary_emb(xq, freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis)

        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))

        if self.kv_cache is not None:
            keys, values = self.kv_cache.update(input_pos, xk, xv)
        else:
            keys, values = xk, xv
        keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
        values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)

        output = F.scaled_dot_product_attention(
            xq, keys, values,
            attn_mask=mask,
            is_causal=True if mask is None else False,  # is_causal=False is for KV cache
            dropout_p=self.attn_dropout_p if self.training else 0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        output = self.resid_dropout(self.wo(output))
        return output

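The attention block above uses grouped key/value heads: each of the n_kv_head key/value heads is shared by n_head // n_kv_head query heads, so keys and values are expanded with repeat_interleave before F.scaled_dot_product_attention. A minimal standalone shape check of that expansion (the sizes are illustrative, not taken from this model):

import torch

n_head, n_kv_head, seq, head_dim = 8, 2, 16, 64
keys = torch.randn(1, n_kv_head, seq, head_dim)            # (bs, n_kv_head, seq, head_dim)
keys = keys.repeat_interleave(n_head // n_kv_head, dim=1)  # one copy per query head
assert keys.shape == (1, n_head, seq, head_dim)
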
class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask))
        out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
        return out

class Transformer(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.block_size = config.block_size
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        self.layer_internal = config.n_layer // 3
        # self.adapter = Adapter(output_dim=768)
        # self.adapter = ViT_Adapter()
        # self.adapter = DeiT_Adapter()
        self.adapter = Dinov2_Adapter(adapter_size=config.adapter_size, condition_type=config.condition_type)
        # self.adapter = EVA_Adapter()
        if config.adapter_size == "small":
            self.adapter_mlp = MLP(384, config.dim, config.dim)
        elif config.adapter_size == 'base':
            self.adapter_mlp = MLP(768, config.dim, config.dim)

        if self.model_type == 'c2i':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 't2i':
            self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
        else:
            raise Exception("please check model type")
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.condition_mlp = ConditionEmbedder(self.block_size, config.dim, config.class_dropout_prob, self.block_size, config.vocab_size)
        self.condition_layers = torch.nn.ModuleList()
        for layer_id in range(3):
            self.condition_layers.append(MLP(config.dim, config.dim, config.dim))

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # 2d rotary pos embedding
        grid_size = int(self.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)

        # KVCache
        self.max_batch_size = -1
        self.max_seq_length = -1

        self.initialize_weights()
        self.condition_token = None
        self.mask = get_causal_mask(256)
        self.global_token = None

    def initialize_weights(self):
        # Initialize nn.Linear and nn.Embedding
        self.apply(self._init_weights)

        # Zero-out output layers:
        nn.init.constant_(self.output.weight, 0)

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        #     return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
        grid_size = int(self.config.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)

    def forward(
        self,
        idx: torch.Tensor,
        cond_idx: torch.Tensor,  # cond_idx_or_embed
        input_pos: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None
    ):
        if idx is not None and cond_idx is not None:  # training or naive inference
            cond_embeddings, drop_ids = self.cls_embedding(cond_idx, train=self.training)
            cond_embeddings = cond_embeddings[:, :self.cls_token_num]
            token_embeddings = self.tok_embeddings(idx)
            if condition is not None:
                condition_embeddings = self.adapter(condition)
                condition_embeddings = self.adapter_mlp(condition_embeddings)
                self.condition_token = self.condition_mlp(condition_embeddings, train=self.training, drop_ids=drop_ids)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)

            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis.to(h.device)
        else:
            if cond_idx is not None:  # prefill in inference
                token_embeddings = self.cls_embedding(cond_idx, train=self.training)
                token_embeddings = token_embeddings[:, :self.cls_token_num]
                if condition is not None:
                    condition_embeddings = self.condition_mlp(condition.to(torch.bfloat16), train=self.training)
                    self.condition_token = condition_embeddings

            else:  # decode_n_tokens(kv cache) in inference
                token_embeddings = self.tok_embeddings(idx)
            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis

        if self.training:
            freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
        else:
            freqs_cis = self.freqs_cis[input_pos]
        # transformer blocks
        for i, layer in enumerate(self.layers):
            if i % self.layer_internal == 0:
                if self.training:
                    h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i // self.layer_internal](self.condition_token)
                else:
                    if len(input_pos) > 1:
                        h[:, -1:] = h[:, -1:] + self.condition_layers[i // self.layer_internal](self.condition_token[:, 0:1])
                    else:
                        h = h + self.condition_layers[i // self.layer_internal](self.condition_token[:, input_pos-self.cls_token_num+1])
            h = layer(h, freqs_cis, input_pos, mask)
        # output layers
        h = self.norm(h)
        logits = self.output(h).float()

        if self.training:
            logits = logits[:, self.cls_token_num - 1:].contiguous()
        # if we are given some desired targets also calculate the loss
        loss = None
        if valid is not None:
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            valid_all = valid[:, None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        return list(self.layers)

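Transformer.forward above injects the control signal on a fixed schedule: every layer_internal = n_layer // 3 blocks, the condition token produced by the adapter and condition_mlp is projected by one of the three condition_layers MLPs and added to the hidden states. A toy sketch of that schedule, with made-up sizes and plain nn.Linear layers standing in for the MLPs (illustrative only, not the implementation above):

import torch
import torch.nn as nn

n_layer, dim = 12, 768
layer_internal = n_layer // 3                       # inject every 4 blocks in this toy setup
condition_layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(3))

h = torch.randn(2, 256, dim)                        # hidden states (bs, seq, dim)
cond = torch.randn(2, 256, dim)                     # per-position condition tokens
for i in range(n_layer):
    if i % layer_internal == 0:
        h = h + condition_layers[i // layer_internal](cond)
    # ...the i-th transformer block would run on h here...
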
#################################################################################
#                      Rotary Positional Embedding Functions                    #
#################################################################################
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)  # (seq_len, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)  # (seq_len, head_dim // 2, 2)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+seq_len, head_dim // 2, 2)
    return cond_cache


def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    # split the dimension into half, one for x and one for y
    half_dim = n_elem // 2
    freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
    t = torch.arange(grid_size, device=freqs.device)
    freqs = torch.outer(t, freqs)  # (grid_size, head_dim // 4)
    freqs_grid = torch.concat([
        freqs[:, None, :].expand(-1, grid_size, -1),
        freqs[None, :, :].expand(grid_size, -1, -1),
    ], dim=-1)  # (grid_size, grid_size, head_dim // 2)
    cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1)  # (grid_size, grid_size, head_dim // 2, 2)
    cache = cache_grid.flatten(0, 1)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+grid_size**2, head_dim // 2, 2)
    return cond_cache


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    # x: (bs, seq_len, n_head, head_dim)
    # freqs_cis: (seq_len, head_dim // 2, 2)
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq_len, n_head, head_dim//2, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)  # (1, seq_len, 1, head_dim//2, 2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)

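A quick sanity check of the shapes produced by the functions above, assuming they are imported from this file; the numbers correspond to a 16x16 token grid with head_dim=64 and the default cls_token_num=120:

import torch

freqs_cis = precompute_freqs_cis_2d(grid_size=16, n_elem=64, cls_token_num=120)
print(freqs_cis.shape)                          # torch.Size([376, 32, 2]) = (120 + 256, head_dim//2, 2)

xq = torch.randn(2, 376, 8, 64)                 # (bs, seq_len, n_head, head_dim)
print(apply_rotary_emb(xq, freqs_cis).shape)    # torch.Size([2, 376, 8, 64]); the shape is unchanged
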
#################################################################################
#                                GPT Configs                                    #
#################################################################################
### text-conditional
def GPT_7B(**kwargs):
    return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs))  # 6.6B

def GPT_3B(**kwargs):
    return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs))  # 3.1B

def GPT_1B(**kwargs):
    return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs))  # 1.2B

### class-conditional
def GPT_XXXL(**kwargs):
    return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs))  # 3.9B

def GPT_XXL(**kwargs):
    return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))  # 1.4B

def GPT_XL(**kwargs):
    return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))  # 775M

def GPT_L(**kwargs):
    return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))  # 343M

def GPT_B(**kwargs):
    return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs))  # 111M


GPT_models = {
    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
}
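A minimal usage sketch for the constructors above. It assumes ModelArgs (defined earlier in this file) provides defaults for the fields not passed here, such as adapter_size and condition_type, and that the DINOv2 weights needed by the adapter are available; it is illustrative rather than a supported entry point:

import torch
from autoregressive.models.gpt_t2i import GPT_models

# class-conditional GPT-B for 256x256 images with a 16x-downsampling VQ tokenizer (16*16 image tokens)
gpt = GPT_models['GPT-B'](
    vocab_size=16384,       # VQ codebook size
    block_size=16 * 16,     # number of image tokens per sample
    num_classes=1000,
    cls_token_num=1,
    model_type='c2i',
)
gpt.eval()
# allocate per-layer KV caches and the causal mask before autoregressive decoding
gpt.setup_caches(max_batch_size=4, max_seq_length=16 * 16 + 1, dtype=torch.bfloat16)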
    	
autoregressive/sample/sample_c2i.py ADDED
@@ -0,0 +1,151 @@
# Modified from:
#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample.py
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
from torchvision.utils import save_image
import os
import sys
current_directory = os.getcwd()
sys.path.append(current_directory)

from PIL import Image
import time
import argparse
from tokenizer.tokenizer_image.vq_model import VQ_models
from autoregressive.models.gpt import GPT_models
from autoregressive.models.generate import generate
from functools import partial
import torch.nn.functional as F
import numpy as np
import cv2


def main(args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=latent_size ** 2,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_token_num=args.condition_token_nums,
        image_size=args.image_size
    ).to(device=device, dtype=precision)

    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint:  # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
        del checkpoint
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo")

    condition_null = None
    if args.condition_type == 'canny':
        sample_list = [650, 2312, 15000, 48850]  # canny
    elif args.condition_type == 'depth':
        sample_list = [101, 4351, 10601, 48901]

    class_labels = [np.load(f"condition/example/c2i/{args.condition_type}/{i}.npy")[0] for i in sample_list]
    condition_imgs = [np.array(Image.open(f"condition/example/c2i/{args.condition_type}/{i}.png"))[None, None, ...] for i in sample_list]
    condition_imgs = torch.from_numpy(np.concatenate(condition_imgs, axis=0)).to(device).to(torch.float32) / 255
    condition_imgs = 2 * (condition_imgs - 0.5)
    print(condition_imgs.shape)
    c_indices = torch.tensor(class_labels, device=device)
    qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
    t1 = time.time()

    index_sample = generate(
        gpt_model, c_indices, latent_size ** 2,
        condition=condition_imgs.repeat(1, 3, 1, 1).to(precision),
        condition_null=condition_null, condition_token_nums=args.condition_token_nums,
        cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
        temperature=args.temperature, top_k=args.top_k,
        top_p=args.top_p, sample_logits=True,
    )

    sampling_time = time.time() - t1
    print(f"gpt sampling takes about {sampling_time:.2f} seconds.")

    t2 = time.time()
    samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
    decoder_time = time.time() - t2
    print(f"decoder takes about {decoder_time:.2f} seconds.")
    # Save and display images:
    condition_imgs = condition_imgs.repeat(1, 3, 1, 1)
    samples = torch.cat((condition_imgs[:4], samples[:4]), dim=0)
    save_image(samples, f"sample/example/sample_{args.gpt_type}_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
    parser.add_argument("--from-fsdp", action='store_true')
    parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--cfg-interval", type=float, default=-1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    parser.add_argument("--condition-token-nums", type=int, default=0)
    parser.add_argument("--condition-type", type=str, default='canny', choices=['canny', 'depth'])
    args = parser.parse_args()
    main(args)
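The script above reads its demo conditions from condition/example/c2i/. A custom control image passed to generate would need the same format: a single-channel map scaled to [-1, 1], repeated to three channels, shaped (N, 3, H, W). A hedged sketch of that preprocessing (the file name is illustrative):

import numpy as np
import torch
from PIL import Image

img = np.array(Image.open("my_canny_edge_map.png"))          # (H, W), uint8 in [0, 255]
cond = torch.from_numpy(img)[None, None, ...].float() / 255  # (1, 1, H, W) in [0, 1]
cond = 2 * (cond - 0.5)                                      # rescale to [-1, 1]
cond = cond.repeat(1, 3, 1, 1)                               # single channel -> 3 channels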
    	
autoregressive/sample/sample_c2i_ddp.py ADDED
@@ -0,0 +1,188 @@
# Modified from:
#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.nn.functional as F
import torch.distributed as dist

from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import math
import argparse

from tokenizer.tokenizer_image.vq_model import VQ_models
from autoregressive.models.gpt import GPT_models
from autoregressive.models.generate import generate


def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path

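create_npz_from_sample_folder packs the per-rank PNG samples into the single-array .npz layout used for FID-style evaluation (the convention inherited from DiT's sample_ddp.py). For a quick smoke test on a small folder it can be called directly; the path and count below are illustrative and assume the numbered PNGs already exist:

npz_path = create_npz_from_sample_folder("samples/smoke_test", num=100)
print(npz_path)  # samples/smoke_test.npz
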
def main(args):
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=latent_size ** 2,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
    ).to(device=device, dtype=precision)
    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
    if args.from_fsdp:  # fsdp
        model_weight = checkpoint
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint:  # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight, maybe add --from-fsdp to run command")
    # if 'freqs_cis' in model_weight:
    #     model_weight.pop('freqs_cis')
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    del checkpoint

    if args.compile:
        print(f"compiling the model...")
            +
                    gpt_model = torch.compile(
         
     | 
| 92 | 
         
            +
                        gpt_model,
         
     | 
| 93 | 
         
            +
                        mode="reduce-overhead",
         
     | 
| 94 | 
         
            +
                        fullgraph=True
         
     | 
| 95 | 
         
            +
                    ) # requires PyTorch 2.0 (optional)
         
     | 
| 96 | 
         
            +
                else:
         
     | 
| 97 | 
         
            +
                    print(f"no model compile") 
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                # Create folder to save samples:
         
     | 
| 100 | 
         
            +
                model_string_name = args.gpt_model.replace("/", "-")
         
     | 
| 101 | 
         
            +
                if args.from_fsdp:
         
     | 
| 102 | 
         
            +
                    ckpt_string_name = args.gpt_ckpt.split('/')[-2]
         
     | 
| 103 | 
         
            +
                else:
         
     | 
| 104 | 
         
            +
                    ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
         
     | 
| 105 | 
         
            +
                folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-size-{args.image_size_eval}-{args.vq_model}-" \
         
     | 
| 106 | 
         
            +
                              f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
         
     | 
| 107 | 
         
            +
                              f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
         
     | 
| 108 | 
         
            +
                sample_folder_dir = f"{args.sample_dir}/{folder_name}"
         
     | 
| 109 | 
         
            +
                if rank == 0:
         
     | 
| 110 | 
         
            +
                    os.makedirs(sample_folder_dir, exist_ok=True)
         
     | 
| 111 | 
         
            +
                    print(f"Saving .png samples at {sample_folder_dir}")
         
     | 
| 112 | 
         
            +
                dist.barrier()
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
         
     | 
| 115 | 
         
            +
                n = args.per_proc_batch_size
         
     | 
| 116 | 
         
            +
                global_batch_size = n * dist.get_world_size()
         
     | 
| 117 | 
         
            +
                # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
         
     | 
| 118 | 
         
            +
                total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
         
     | 
| 119 | 
         
            +
                if rank == 0:
         
     | 
| 120 | 
         
            +
                    print(f"Total number of images that will be sampled: {total_samples}")
         
     | 
| 121 | 
         
            +
                assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
         
     | 
| 122 | 
         
            +
                samples_needed_this_gpu = int(total_samples // dist.get_world_size())
         
     | 
| 123 | 
         
            +
                assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
         
     | 
| 124 | 
         
            +
                iterations = int(samples_needed_this_gpu // n)
         
     | 
| 125 | 
         
            +
                pbar = range(iterations)
         
     | 
| 126 | 
         
            +
                pbar = tqdm(pbar) if rank == 0 else pbar
         
     | 
| 127 | 
         
            +
                total = 0
         
     | 
| 128 | 
         
            +
                for _ in pbar:
         
     | 
| 129 | 
         
            +
                    # Sample inputs:
         
     | 
| 130 | 
         
            +
                    c_indices = torch.randint(0, args.num_classes, (n,), device=device)
         
     | 
| 131 | 
         
            +
                    qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                    index_sample = generate(
         
     | 
| 134 | 
         
            +
                        gpt_model, c_indices, latent_size ** 2,
         
     | 
| 135 | 
         
            +
                        cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
         
     | 
| 136 | 
         
            +
                        temperature=args.temperature, top_k=args.top_k,
         
     | 
| 137 | 
         
            +
                        top_p=args.top_p, sample_logits=True, 
         
     | 
| 138 | 
         
            +
                        )
         
     | 
| 139 | 
         
            +
                    
         
     | 
| 140 | 
         
            +
                    samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
         
     | 
| 141 | 
         
            +
                    if args.image_size_eval != args.image_size:
         
     | 
| 142 | 
         
            +
                        samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
         
     | 
| 143 | 
         
            +
                    samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
         
     | 
| 144 | 
         
            +
                    
         
     | 
| 145 | 
         
            +
                    # Save samples to disk as individual .png files
         
     | 
| 146 | 
         
            +
                    for i, sample in enumerate(samples):
         
     | 
| 147 | 
         
            +
                        index = i * dist.get_world_size() + rank + total
         
     | 
| 148 | 
         
            +
                        Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
         
     | 
| 149 | 
         
            +
                    total += global_batch_size
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
                # Make sure all processes have finished saving their samples before attempting to convert to .npz
         
     | 
| 152 | 
         
            +
                dist.barrier()
         
     | 
| 153 | 
         
            +
                if rank == 0:
         
     | 
| 154 | 
         
            +
                    create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
         
     | 
| 155 | 
         
            +
                    print("Done.")
         
     | 
| 156 | 
         
            +
                dist.barrier()
         
     | 
| 157 | 
         
            +
                dist.destroy_process_group()
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 162 | 
         
            +
                parser = argparse.ArgumentParser()
         
     | 
| 163 | 
         
            +
                parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
         
     | 
| 164 | 
         
            +
                parser.add_argument("--gpt-ckpt", type=str, default=None)
         
     | 
| 165 | 
         
            +
                parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
         
     | 
| 166 | 
         
            +
                parser.add_argument("--from-fsdp", action='store_true')
         
     | 
| 167 | 
         
            +
                parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
         
     | 
| 168 | 
         
            +
                parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
         
     | 
| 169 | 
         
            +
                parser.add_argument("--compile", action='store_true', default=True)
         
     | 
| 170 | 
         
            +
                parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
         
     | 
| 171 | 
         
            +
                parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
         
     | 
| 172 | 
         
            +
                parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
         
     | 
| 173 | 
         
            +
                parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
         
     | 
| 174 | 
         
            +
                parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
         
     | 
| 175 | 
         
            +
                parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
         
     | 
| 176 | 
         
            +
                parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
         
     | 
| 177 | 
         
            +
                parser.add_argument("--num-classes", type=int, default=1000)
         
     | 
| 178 | 
         
            +
                parser.add_argument("--cfg-scale",  type=float, default=1.5)
         
     | 
| 179 | 
         
            +
                parser.add_argument("--cfg-interval", type=float, default=-1)
         
     | 
| 180 | 
         
            +
                parser.add_argument("--sample-dir", type=str, default="samples")
         
     | 
| 181 | 
         
            +
                parser.add_argument("--per-proc-batch-size", type=int, default=32)
         
     | 
| 182 | 
         
            +
                parser.add_argument("--num-fid-samples", type=int, default=5000)
         
     | 
| 183 | 
         
            +
                parser.add_argument("--global-seed", type=int, default=0)
         
     | 
| 184 | 
         
            +
                parser.add_argument("--top-k", type=int, default=0,help="top-k value to sample with")
         
     | 
| 185 | 
         
            +
                parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
         
     | 
| 186 | 
         
            +
                parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
         
     | 
| 187 | 
         
            +
                args = parser.parse_args()
         
     | 
| 188 | 
         
            +
                main(args)
         
     | 
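For reference, the sample-count bookkeeping in `main` above (rounding `num_fid_samples` up to a full global batch and interleaving file indices across ranks) can be checked in isolation. The following standalone sketch is illustrative only, uses made-up toy numbers, and is not part of the committed files:

import math

# Toy numbers (illustrative): 5000 requested FID samples, 8 ranks, 32 images per rank per batch.
num_fid_samples, world_size, per_proc_batch_size = 5000, 8, 32
global_batch_size = per_proc_batch_size * world_size

# Round up to a full global batch so every rank runs the same number of iterations.
total_samples = int(math.ceil(num_fid_samples / global_batch_size) * global_batch_size)
iterations = total_samples // world_size // per_proc_batch_size

# Reproduce the file-index formula used above: index = i * world_size + rank + total
written = set()
for rank in range(world_size):
    total = 0
    for _ in range(iterations):
        for i in range(per_proc_batch_size):
            written.add(i * world_size + rank + total)
        total += global_batch_size

assert written == set(range(total_samples))       # every slot is written exactly once
assert set(range(num_fid_samples)) <= written     # the first num_fid_samples PNGs exist for the .npz
print(total_samples, iterations)                  # 5120 20

Because every index below total_samples is written exactly once, the rank-0 call to create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) always finds the first num_fid_samples PNGs it needs, and the surplus images are simply ignored.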
    	
autoregressive/sample/sample_t2i.py  ADDED  @@ -0,0 +1,215 @@
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from torchvision.utils import save_image

import os
import sys
current_directory = os.getcwd()
sys.path.append(current_directory)
import time
import argparse
from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
from autoregressive.models.gpt_t2i import GPT_models
from autoregressive.models.generate import generate
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from dataset.t2i_control import build_t2i_control_code
from accelerate import Accelerator
from dataset.build import build_dataset
from pathlib import Path
from accelerate.utils import ProjectConfiguration, set_seed
import torch.nn.functional as F
from condition.canny import CannyDetector
from condition.hed import HEDdetector
import numpy as np
from PIL import Image
from condition.lineart import LineArt
import cv2
from transformers import DPTImageProcessor, DPTForDepthEstimation


def main(args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print("image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_type=args.condition_type,
    ).to(device=device, dtype=precision)

    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint:  # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
        del checkpoint
    print("gpt model is loaded")

    if args.compile:
        print("compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print("no need to compile model in demo")

    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )

    if args.condition_type == 'canny':
        get_control = CannyDetector()
    elif args.condition_type == 'hed':
        get_control = HEDdetector().to(device).eval()
    elif args.condition_type == 'lineart':
        get_control = LineArt()
        get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
        get_control.to(device)
    elif args.condition_type == 'depth':
        processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
        model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)

    with torch.no_grad():
        condition_path = args.condition_path
        if args.condition_type == 'seg':
            condition_img = torch.from_numpy(np.array(Image.open(condition_path)))
            condition_img = condition_img.permute(2, 0, 1).unsqueeze(0).repeat(2, 1, 1, 1)
        elif args.condition_type == 'canny':
            condition_img = get_control(np.array(Image.open(condition_path)))
            condition_img = torch.from_numpy(condition_img[None, None, ...]).repeat(2, 3, 1, 1)
        elif args.condition_type == 'hed':
            condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2, 0, 1).unsqueeze(0).to(device))
            condition_img = condition_img.unsqueeze(1).repeat(2, 3, 1, 1)
        elif args.condition_type == 'lineart':
            condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2, 0, 1).unsqueeze(0).to(device).float())
            condition_img = condition_img.repeat(2, 3, 1, 1) * 255
        elif args.condition_type == 'depth':
            images = Image.open(condition_path)
            inputs = processor(images=images, return_tensors="pt", size=(512, 512)).to(device)
            outputs = model(**inputs)
            condition_img = outputs.predicted_depth
            condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
            condition_img = (condition_img * 255 / condition_img.max())
        condition_img = condition_img.to(device)
        condition_img = 2 * (condition_img / 255 - 0.5)
        prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
        prompts = prompts * 2
        caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)

        if not args.no_left_padding:
            print("processing left-padding...")
            # a naive way to implement left-padding
            new_emb_masks = torch.flip(emb_masks, dims=[-1])
            new_caption_embs = []
            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
                valid_num = int(emb_mask.sum().item())
                print(f'  prompt {idx} token len: {valid_num}')
                new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]])
                new_caption_embs.append(new_caption_emb)
            new_caption_embs = torch.stack(new_caption_embs)
        else:
            new_caption_embs, new_emb_masks = caption_embs, emb_masks
        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), args.codebook_embed_dim, args.image_H // args.downsample_size, args.image_W // args.downsample_size]
        t1 = time.time()
        index_sample = generate(
            gpt_model, c_indices, (args.image_H // args.downsample_size) * (args.image_W // args.downsample_size),  # latent_size ** 2
            c_emb_masks, condition=condition_img.to(precision),
            cfg_scale=args.cfg_scale,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True,
            )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        samples = torch.cat((condition_img[0:1], samples), dim=0)
        save_image(samples, f"sample/example/sample_t2i_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))
        print(f"image is saved to sample/example/sample_t2i_{args.condition_type}.png")
        print(prompts)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--t5-feature-max-len", type=int, default=120)
    parser.add_argument("--t5-feature-dim", type=int, default=2048)
    parser.add_argument("--no-left-padding", action='store_true', default=False)
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
    parser.add_argument("--image-H", type=int, default=512)
    parser.add_argument("--image-W", type=int, default=512)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--cfg-scale", type=float, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")

    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny")
    parser.add_argument("--prompt", type=str, default='a high-quality image')
    parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
    args = parser.parse_args()
    main(args)
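The no_left_padding branch above rotates each right-padded T5 embedding so that padding moves to the front of the sequence, presumably so the real prompt tokens sit immediately before the image tokens in the fixed-length condition window. A minimal standalone sketch of that rotation, using made-up shapes rather than real T5 features (illustrative only, not part of the committed files):

import torch

# A right-padded sequence [t1 t2 t3 pad pad] becomes left-padded [pad pad t1 t2 t3].
T, D = 5, 4                                   # toy sequence length and feature dim
caption_emb = torch.arange(T * D, dtype=torch.float32).reshape(T, D)
emb_mask = torch.tensor([1, 1, 1, 0, 0])      # 3 valid tokens, 2 padding slots

valid_num = int(emb_mask.sum().item())
new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]])  # rotate pads to the front
new_emb_mask = torch.flip(emb_mask, dims=[-1])                                   # mask becomes [0, 0, 1, 1, 1]

print(new_emb_mask.tolist())           # [0, 0, 1, 1, 1]
print(new_caption_emb[-1, 0].item())   # 8.0 -> the last valid token's row now sits at the end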
    	
autoregressive/sample/sample_t2i_MR.py  ADDED  @@ -0,0 +1,237 @@
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from torchvision.utils import save_image

import os
import sys
current_directory = os.getcwd()
sys.path.append(current_directory)
import time
import argparse
from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
from autoregressive.models.gpt_t2i import GPT_models
from autoregressive.models.generate import generate
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from dataset.t2i_control import build_t2i_control_code
from accelerate import Accelerator
from dataset.build import build_dataset
from pathlib import Path
from accelerate.utils import ProjectConfiguration, set_seed
import torch.nn.functional as F
from condition.canny import CannyDetector
from condition.hed import HEDdetector
import numpy as np
from PIL import Image
from condition.lineart import LineArt
import cv2
from transformers import DPTImageProcessor, DPTForDepthEstimation
from condition.midas.depth import MidasDetector


def resize_image_to_16_multiple(image_path, condition_type='seg'):
    image = Image.open(image_path)
    width, height = image.size

    if condition_type == 'depth':  # the depth model requires side lengths that are multiples of 32
        new_width = (width + 31) // 32 * 32
        new_height = (height + 31) // 32 * 32
    else:
        new_width = (width + 15) // 16 * 16
        new_height = (height + 15) // 16 * 16

    resized_image = image.resize((new_width, new_height))
    return resized_image


def main(args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print("image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_type=args.condition_type,
    ).to(device=device, dtype=precision)

    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint:  # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
        del checkpoint
    print("gpt model is loaded")

    if args.compile:
        print("compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print("no need to compile model in demo")

    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )

    if args.condition_type == 'canny':
        get_control = CannyDetector()
    elif args.condition_type == 'hed':
        get_control = HEDdetector().to(device).eval()
    elif args.condition_type == 'lineart':
        get_control = LineArt()
        get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
        get_control.to(device)
    elif args.condition_type == 'depth':
        processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
        model_large = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)
        model = MidasDetector(device=device)
| 133 | 
         
            +
                with torch.no_grad():
         
     | 
| 134 | 
         
            +
                    
         
     | 
| 135 | 
         
            +
                    condition_img = resize_image_to_16_multiple(args.condition_path, args.condition_type)
         
     | 
| 136 | 
         
            +
                    W, H = condition_img.size
         
     | 
| 137 | 
         
            +
                    print(H,W)
         
     | 
| 138 | 
         
            +
                    if args.condition_type == 'seg':
         
     | 
| 139 | 
         
            +
                        condition_img = torch.from_numpy(np.array(condition_img))
         
     | 
| 140 | 
         
            +
                        condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1)
         
     | 
| 141 | 
         
            +
                    elif args.condition_type == 'canny':
         
     | 
| 142 | 
         
            +
                        condition_img = get_control(np.array(condition_img))
         
     | 
| 143 | 
         
            +
                        condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1)
         
     | 
| 144 | 
         
            +
                    elif args.condition_type == 'hed':
         
     | 
| 145 | 
         
            +
                        condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device))
         
     | 
| 146 | 
         
            +
                        condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1)
         
     | 
| 147 | 
         
            +
                    elif args.condition_type == 'lineart':
         
     | 
| 148 | 
         
            +
                        condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device).float())
         
     | 
| 149 | 
         
            +
                        condition_img = condition_img.repeat(2,3,1,1) * 255
         
     | 
| 150 | 
         
            +
                    elif args.condition_type == 'depth':
         
     | 
| 151 | 
         
            +
                        images = condition_img
         
     | 
| 152 | 
         
            +
                        if H == W:
         
     | 
| 153 | 
         
            +
                            inputs = processor(images=images, return_tensors="pt", size=(H,W)).to(device)
         
     | 
| 154 | 
         
            +
                            outputs = model_large(**inputs)
         
     | 
| 155 | 
         
            +
                            condition_img = outputs.predicted_depth
         
     | 
| 156 | 
         
            +
                            condition_img = (condition_img * 255 / condition_img.max())
         
     | 
| 157 | 
         
            +
                        else:
         
     | 
| 158 | 
         
            +
                            condition_img = torch.from_numpy(model(torch.from_numpy(np.array(condition_img)).to(device))).unsqueeze(0)
         
     | 
| 159 | 
         
            +
                        condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1)
         
     | 
| 160 | 
         
            +
                    condition_img = condition_img.to(device)
         
     | 
| 161 | 
         
            +
                    condition_img = 2*(condition_img/255 - 0.5)
         
     | 
| 162 | 
         
            +
                    prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
         
     | 
| 163 | 
         
            +
                    prompts = prompts * 2
         
     | 
| 164 | 
         
            +
                    caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
                    if not args.no_left_padding:
         
     | 
| 167 | 
         
            +
                        print(f"processing left-padding...")    
         
     | 
| 168 | 
         
            +
                        # a naive way to implement left-padding
         
     | 
| 169 | 
         
            +
                        new_emb_masks = torch.flip(emb_masks, dims=[-1])
         
     | 
| 170 | 
         
            +
                        new_caption_embs = []
         
     | 
| 171 | 
         
            +
                        for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
         
     | 
| 172 | 
         
            +
                            valid_num = int(emb_mask.sum().item())
         
     | 
| 173 | 
         
            +
                            print(f'  prompt {idx} token len: {valid_num}')
         
     | 
| 174 | 
         
            +
                            new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]])
         
     | 
| 175 | 
         
            +
                            new_caption_embs.append(new_caption_emb)
         
     | 
| 176 | 
         
            +
                        new_caption_embs = torch.stack(new_caption_embs)
         
     | 
| 177 | 
         
            +
                    else:
         
     | 
| 178 | 
         
            +
                        new_caption_embs, new_emb_masks = caption_embs, emb_masks
         
     | 
| 179 | 
         
            +
                    c_indices = new_caption_embs * new_emb_masks[:,:, None]
         
     | 
| 180 | 
         
            +
                    c_emb_masks = new_emb_masks
         
     | 
| 181 | 
         
            +
                    qzshape = [len(c_indices), args.codebook_embed_dim, H//args.downsample_size, W//args.downsample_size]
         
     | 
| 182 | 
         
            +
                    t1 = time.time()
         
     | 
| 183 | 
         
            +
                    index_sample = generate(
         
     | 
| 184 | 
         
            +
                        gpt_model, c_indices, (H//args.downsample_size)*(W//args.downsample_size),#latent_size ** 2, 
         
     | 
| 185 | 
         
            +
                        c_emb_masks, condition=condition_img.to(precision),
         
     | 
| 186 | 
         
            +
                        cfg_scale=args.cfg_scale,
         
     | 
| 187 | 
         
            +
                        temperature=args.temperature, top_k=args.top_k,
         
     | 
| 188 | 
         
            +
                        top_p=args.top_p, sample_logits=True, 
         
     | 
| 189 | 
         
            +
                        )
         
     | 
| 190 | 
         
            +
                    sampling_time = time.time() - t1
         
     | 
| 191 | 
         
            +
                    print(f"Full sampling takes about {sampling_time:.2f} seconds.")    
         
     | 
| 192 | 
         
            +
                    
         
     | 
| 193 | 
         
            +
                    t2 = time.time()
         
     | 
| 194 | 
         
            +
                    print(index_sample.shape)
         
     | 
| 195 | 
         
            +
                    samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
         
     | 
| 196 | 
         
            +
                    decoder_time = time.time() - t2
         
     | 
| 197 | 
         
            +
                    print(f"decoder takes about {decoder_time:.2f} seconds.")
         
     | 
| 198 | 
         
            +
             
     | 
| 199 | 
         
            +
                    samples = torch.cat((condition_img[0:1], samples), dim=0)
         
     | 
| 200 | 
         
            +
                    save_image(samples, f"sample/example/sample_t2i_MR_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))
         
     | 
| 201 | 
         
            +
                    print(f"image is saved to sample/example/sample_t2i_MR_{args.condition_type}.png")
         
     | 
| 202 | 
         
            +
                    print(prompts)
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 206 | 
         
            +
                parser = argparse.ArgumentParser()
         
     | 
| 207 | 
         
            +
                parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
         
     | 
| 208 | 
         
            +
                parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
         
     | 
| 209 | 
         
            +
                parser.add_argument("--t5-feature-max-len", type=int, default=120)
         
     | 
| 210 | 
         
            +
                parser.add_argument("--t5-feature-dim", type=int, default=2048)
         
     | 
| 211 | 
         
            +
                parser.add_argument("--no-left-padding", action='store_true', default=False)
         
     | 
| 212 | 
         
            +
                parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
         
     | 
| 213 | 
         
            +
                parser.add_argument("--gpt-ckpt", type=str, default=None)
         
     | 
| 214 | 
         
            +
                parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")  
         
     | 
| 215 | 
         
            +
                parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
         
     | 
| 216 | 
         
            +
                parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
         
     | 
| 217 | 
         
            +
                parser.add_argument("--compile", action='store_true', default=False)
         
     | 
| 218 | 
         
            +
                parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
         
     | 
| 219 | 
         
            +
                parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
         
     | 
| 220 | 
         
            +
                parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
         
     | 
| 221 | 
         
            +
                parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
         
     | 
| 222 | 
         
            +
                parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
         
     | 
| 223 | 
         
            +
                parser.add_argument("--image-H", type=int, default=512)
         
     | 
| 224 | 
         
            +
                parser.add_argument("--image-W", type=int, default=512)
         
     | 
| 225 | 
         
            +
                parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
         
     | 
| 226 | 
         
            +
                parser.add_argument("--cfg-scale", type=float, default=4)
         
     | 
| 227 | 
         
            +
                parser.add_argument("--seed", type=int, default=0)
         
     | 
| 228 | 
         
            +
                parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
         
     | 
| 229 | 
         
            +
                parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
         
     | 
| 230 | 
         
            +
                parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
         
     | 
| 231 | 
         
            +
             
     | 
| 232 | 
         
            +
                parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
         
     | 
| 233 | 
         
            +
                parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny")
         
     | 
| 234 | 
         
            +
                parser.add_argument("--prompt", type=str, default='a high-quality image')
         
     | 
| 235 | 
         
            +
                parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
         
     | 
| 236 | 
         
            +
                args = parser.parse_args()
         
     | 
| 237 | 
         
            +
                main(args)
         
     | 
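The multi-resolution path above only requires that both sides of the condition image are divisible by `args.downsample_size` (16 for VQ-16); the number of tokens the GPT samples is then `(H//16) * (W//16)` rather than a fixed square `latent_size ** 2`. Below is a minimal sketch of that constraint, assuming a simple round-down resize; the actual `resize_image_to_16_multiple` helper is defined earlier in this script and may behave differently.

```python
from PIL import Image

def resize_to_multiple_of_16(path):
    # Hypothetical illustration only: round both sides down to a multiple of 16
    # so that H // 16 and W // 16 give an exact latent grid.
    img = Image.open(path).convert("RGB")
    W, H = img.size
    return img.resize(((W // 16) * 16, (H // 16) * 16))

# Example: a 768x512 condition image yields a latent grid of
# (512 // 16) x (768 // 16) = 32 x 48, i.e. 1536 tokens to sample,
# and qzshape becomes [batch, codebook_embed_dim, 32, 48].
```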
    	
        autoregressive/sample/sample_t2i_ddp.py
    ADDED
    
@@ -0,0 +1,229 @@
+import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision('high')
+setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
+setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
+import torch.nn.functional as F
+import torch.distributed as dist
+
+import os
+import math
+import json
+import argparse
+import pandas as pd
+from tqdm import tqdm
+from PIL import Image
+
+from tokenizer.tokenizer_image.vq_model import VQ_models
+from language.t5 import T5Embedder
+from autoregressive.models.gpt import GPT_models
+from autoregressive.models.generate import generate
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+
+def main(args):
+    # Setup PyTorch:
+    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
+    torch.set_grad_enabled(False)
+
+    # Setup DDP:
+    dist.init_process_group("nccl")
+    rank = dist.get_rank()
+    device = rank % torch.cuda.device_count()
+    seed = args.global_seed * dist.get_world_size() + rank
+    torch.manual_seed(seed)
+    torch.cuda.set_device(device)
+    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
+
+    # create and load model
+    vq_model = VQ_models[args.vq_model](
+        codebook_size=args.codebook_size,
+        codebook_embed_dim=args.codebook_embed_dim)
+    vq_model.to(device)
+    vq_model.eval()
+    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
+    vq_model.load_state_dict(checkpoint["model"])
+    del checkpoint
+    print(f"image tokenizer is loaded")
+
+    # create and load gpt model
+    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
+    latent_size = args.image_size // args.downsample_size
+    gpt_model = GPT_models[args.gpt_model](
+        block_size=latent_size ** 2,
+        cls_token_num=args.cls_token_num,
+        model_type=args.gpt_type,
+    ).to(device=device, dtype=precision)
+
+    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
+
+    if "model" in checkpoint:  # ddp
+        model_weight = checkpoint["model"]
+    elif "module" in checkpoint: # deepspeed
+        model_weight = checkpoint["module"]
+    elif "state_dict" in checkpoint:
+        model_weight = checkpoint["state_dict"]
+    else:
+        raise Exception("please check model weight")
+    gpt_model.load_state_dict(model_weight, strict=False)
+    gpt_model.eval()
+    del checkpoint
+    print(f"gpt model is loaded")
+
+    if args.compile:
+        print(f"compiling the model...")
+        gpt_model = torch.compile(
+            gpt_model,
+            mode="reduce-overhead",
+            fullgraph=True
+        ) # requires PyTorch 2.0 (optional)
+    else:
+        print(f"no need to compile model in demo")
+
+    assert os.path.exists(args.t5_path)
+    t5_model = T5Embedder(
+        device=device,
+        local_cache=True,
+        cache_dir=args.t5_path,
+        dir_or_name=args.t5_model_type,
+        torch_dtype=precision,
+        model_max_length=args.t5_feature_max_len,
+    )
+    print(f"t5 model is loaded")
+
+    # Create folder to save samples:
+    model_string_name = args.gpt_model.replace("/", "-")
+    ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
+    prompt_name = args.prompt_csv.split('/')[-1].split('.')[0].lower()
+    folder_name = f"{model_string_name}-{ckpt_string_name}-{prompt_name}-size-{args.image_size}-size-{args.image_size}-{args.vq_model}-" \
+                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
+                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
+    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
+    if rank == 0:
+        os.makedirs(f"{sample_folder_dir}/images", exist_ok=True)
+        print(f"Saving .png samples at {sample_folder_dir}/images")
+    dist.barrier()
+
+    df = pd.read_csv(args.prompt_csv, delimiter='\t')
+    prompt_list = df['Prompt'].tolist()
+
+    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
+    n = args.per_proc_batch_size
+    global_batch_size = n * dist.get_world_size()
+    num_fid_samples = min(args.num_fid_samples, len(prompt_list))
+    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
+    total_samples = int(math.ceil(num_fid_samples / global_batch_size) * global_batch_size)
+    if rank == 0:
+        print(f"Total number of images that will be sampled: {total_samples}")
+    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
+    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
+    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
+    iterations = int(samples_needed_this_gpu // n)
+    pbar = range(iterations)
+    pbar = tqdm(pbar) if rank == 0 else pbar
+    total = 0
+    for _ in pbar:
+        # Select text prompt
+        prompt_batch = []
+        for i in range(n):
+            index = i * dist.get_world_size() + rank + total
+            prompt_batch.append(prompt_list[index] if index < len(prompt_list) else "a cute dog")
+
+        # Sample inputs:
+        caption_embs, emb_masks = t5_model.get_text_embeddings(prompt_batch)
+
+        if not args.no_left_padding:
+            new_emb_masks = torch.flip(emb_masks, dims=[-1])
+            new_caption_embs = []
+            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
+                valid_num = int(emb_mask.sum().item())
+                # prompt_cur = prompt_batch[idx]
+                # print(f'  prompt {idx} token len: {valid_num} : {prompt_cur}')
+                new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]])
+                new_caption_embs.append(new_caption_emb)
+            new_caption_embs = torch.stack(new_caption_embs)
+
+        else:
+            new_caption_embs, new_emb_masks = caption_embs, emb_masks
+
+        c_indices = new_caption_embs * new_emb_masks[:,:, None]
+        c_emb_masks = new_emb_masks
+
+        qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]
+        index_sample = generate(
+            gpt_model, c_indices, latent_size ** 2,
+            c_emb_masks,
+            cfg_scale=args.cfg_scale,
+            temperature=args.temperature, top_k=args.top_k,
+            top_p=args.top_p, sample_logits=True,
+            )
+
+        samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
+        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
+
+        # Save samples to disk as individual .png files
+        for i, sample in enumerate(samples):
+            index = i * dist.get_world_size() + rank + total
+            Image.fromarray(sample).save(f"{sample_folder_dir}/images/{index:06d}.png")
+        total += global_batch_size
+
+    # Make sure all processes have finished saving their samples before attempting to convert to .npz
+    dist.barrier()
+    if rank == 0:
+        # Save infer result in a jsonl file
+        json_items = []
+        for idx, prompt in enumerate(prompt_list):
+            image_path = os.path.join(sample_folder_dir, "images", f"{idx:06d}.png")
+            json_items.append({"text": prompt, "image_path": image_path})
+        res_jsonl_path = os.path.join(sample_folder_dir, "result.jsonl")
+        print(f"Save jsonl to {res_jsonl_path}...")
+        with open(res_jsonl_path, "w") as f:
+            for item in json_items:
+                f.write(json.dumps(item) + "\n")
+
+        # Save captions to txt
+        caption_path = os.path.join(sample_folder_dir, "captions.txt")
+        print(f"Save captions to {caption_path}...")
+        with open(caption_path, "w") as f:
+            for item in prompt_list:
+                f.write(f"{item}\n")
+        print("Done.")
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--prompt-csv", type=str, default='evaluations/t2i/PartiPrompts.tsv')
+    parser.add_argument("--t5-path", type=str, default='pretrained_models/t5-ckpt')
+    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
+    parser.add_argument("--t5-feature-max-len", type=int, default=120)
+    parser.add_argument("--t5-feature-dim", type=int, default=2048)
+    parser.add_argument("--no-left-padding", action='store_true', default=False)
+    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
+    parser.add_argument("--gpt-ckpt", type=str, default=None)
+    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
+    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
+    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
+    parser.add_argument("--compile", action='store_true', default=False)
+    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
+    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
+    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
+    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
+    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=512)
+    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
+    parser.add_argument("--num-classes", type=int, default=1000)
+    parser.add_argument("--cfg-scale", type=float, default=7.5)
+    parser.add_argument("--sample-dir", type=str, default="samples_parti", help="samples_coco or samples_parti")
+    parser.add_argument("--per-proc-batch-size", type=int, default=32)
+    parser.add_argument("--num-fid-samples", type=int, default=30000)
+    parser.add_argument("--global-seed", type=int, default=0)
+    parser.add_argument("--top-k", type=int, default=1000, help="top-k value to sample with")
+    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
+    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
+    args = parser.parse_args()
+    main(args)
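Because this script calls `dist.init_process_group("nccl")`, it is typically launched with `torchrun` (one process per GPU). Prompts are sharded across ranks in an interleaved fashion by `index = i * world_size + rank + total`, and the same index names each saved image, so `result.jsonl` can pair prompts with images by position. A small illustrative sketch of that indexing, with the world size and batch size chosen only for the example:

```python
# Illustration of the rank/prompt interleaving used in the sampling loop above.
world_size = 2   # assumed number of processes (one per GPU)
n = 4            # assumed args.per_proc_batch_size
total = 0        # advanced by global_batch_size (= n * world_size) each iteration

for rank in range(world_size):
    indices = [i * world_size + rank + total for i in range(n)]
    print(f"rank {rank} handles prompts {indices}")
# rank 0 handles prompts [0, 2, 4, 6]
# rank 1 handles prompts [1, 3, 5, 7]
# Each image is saved as images/{index:06d}.png, so image 000003.png always
# corresponds to the fourth prompt in the CSV regardless of which rank produced it.
```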
    	
        checkpoints/vq_ds16_t2i.pt
    ADDED
    
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e21fc1318e2e9ee641a07bdad0e20675e9ec35e6e3eb911d58b5d7a2cd8d4cb
+size 287920306
    	
        condition/README.md
    ADDED
    
@@ -0,0 +1,23 @@
+Prepare the preprocessing models
+
+Hed: https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth\
+Lineart: https://huggingface.co/spaces/awacke1/Image-to-Line-Drawings/resolve/main/model.pth\
+Depth: https://huggingface.co/lllyasviel/Annotators/blob/main/dpt_hybrid-midas-501f0c75.pt (hybrid, for inference)\
+       https://huggingface.co/Intel/dpt-large (large, for testing conditional consistency and FID)\
+
+We recommend storing them in the following paths:
+
+    |---condition
+        |---ckpts
+            |---dpt_large
+                |---config.json
+                |---preprocessor_config.json
+                |---pytorch_model.bin
+            |---ControlNetHED.pth
+            |---dpt_hybrid-midas-501f0c75.pt
+            |---model.pth
+        |---example
+        |---midas
+        .
+        .
+        .
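For reference, a minimal sketch of one way to fetch these checkpoints into the recommended layout, assuming `huggingface_hub` is installed; the repo ids and filenames are taken from the links above, and this snippet is illustrative rather than part of the repository:

```python
# Sketch only: fetch the annotator weights into condition/ckpts/ as recommended.
from huggingface_hub import hf_hub_download, snapshot_download

hf_hub_download(repo_id="lllyasviel/Annotators", filename="ControlNetHED.pth",
                local_dir="condition/ckpts")
hf_hub_download(repo_id="lllyasviel/Annotators", filename="dpt_hybrid-midas-501f0c75.pt",
                local_dir="condition/ckpts")
snapshot_download(repo_id="Intel/dpt-large", local_dir="condition/ckpts/dpt_large")
# The lineart model.pth is hosted in a Space, so download it manually from the
# URL above and place it at condition/ckpts/model.pth.
```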
    	
        condition/canny.py
    ADDED
    
@@ -0,0 +1,25 @@
+import cv2
+import torch
+import numpy as np
+
+
+class CannyDetector:
+    def __call__(self, img, low_threshold=100, high_threshold=200):
+        """
+        input: array or tensor (H,W,3)
+        output: array (H,W)
+        """
+        if torch.is_tensor(img):
+            img = img.cpu().detach().numpy().astype(np.uint8)
+        return cv2.Canny(img, low_threshold, high_threshold)
+
+
+if __name__ == '__main__':
+    apply_canny = CannyDetector()
+    img = cv2.imread('condition/dragon_resize.png')
+    import numpy as np
+    print(img.max())
+    detected_map = apply_canny(img, 100, 200)
+    print(detected_map.shape, detected_map.max(), detected_map.min())
+    cv2.imwrite('condition/example_canny.jpg', detected_map)
+    np.save('condition/example_canny.npy', detected_map[None,None])
    	
        condition/depth.py
    ADDED
    
@@ -0,0 +1,47 @@
+from controlnet_aux import LineartDetector
+import torch
+import cv2
+import numpy as np
+from transformers import DPTImageProcessor, DPTForDepthEstimation
+class Depth:
+    def __init__(self, device):
+        self.model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large")
+
+    def __call__(self, input_image):
+        """
+        input: tensor()
+        """
+        control_image = self.model(input_image)
+        return np.array(control_image)
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+    from tqdm import tqdm
+    from transformers import DPTImageProcessor, DPTForDepthEstimation
+    from PIL import Image
+
+    image = Image.open('condition/example/t2i/depth/depth.png')
+    img = cv2.imread('condition/example/t2i/depth/depth.png')
+    processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
+    model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large")
+
+    inputs = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).float()
+    inputs = 2*(inputs/255 - 0.5)
+    inputs = processor(images=image, return_tensors="pt", size=(512,512))
+    print(inputs)
+    with torch.no_grad():
+        outputs = model(**inputs)
+        predicted_depth = outputs.predicted_depth
+    print(predicted_depth.shape)
+    prediction = torch.nn.functional.interpolate(
+        predicted_depth.unsqueeze(1),
+        size=image.size[::-1],
+        mode="bicubic",
+        align_corners=False,
+    )
+
+    output = prediction.squeeze().cpu().numpy()
+    formatted = (output * 255 / np.max(output)).astype("uint8")
+
+    depth = Image.fromarray(formatted)
+    depth.save('condition/example/t2i/depth/example_depth.jpg')
    	
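A side note on the DPT usage in the __main__ block above: if the local condition/ckpts/dpt_large copy is not available, the same preprocessing can be sketched against the Hub checkpoint. This is a minimal sketch, not part of the commit; it assumes network access and the "Intel/dpt-large" repo id.

# Minimal sketch (not part of the commit): same DPT depth extraction as the
# __main__ block above, but loading the checkpoint from the Hub instead of the
# local condition/ckpts/dpt_large copy. "Intel/dpt-large" is an assumption.
import torch
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation

processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

image = Image.open('condition/example/t2i/depth/depth.png')
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    predicted_depth = model(**inputs).predicted_depth      # (1, h, w)

# resize back to the original resolution and rescale to 0..255
prediction = torch.nn.functional.interpolate(
    predicted_depth.unsqueeze(1), size=image.size[::-1],
    mode="bicubic", align_corners=False,
)
depth = (prediction.squeeze().numpy() * 255 / prediction.max().item()).astype("uint8")
Image.fromarray(depth).save('depth_from_hub.jpg')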
condition/example/t2i/multi_resolution/bird.jpg
ADDED
condition/example/t2i/multi_resolution/car.jpg
ADDED
condition/example/t2i/multigen/doll.jpg
ADDED
condition/example/t2i/multigen/girl.jpg
ADDED
condition/example/t2i/multigen/house.jpg
ADDED
condition/example/t2i/multigen/sofa.png
ADDED
    	
condition/hed.py
ADDED
@@ -0,0 +1,117 @@
+# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
+# Please use this implementation in your products
+# This implementation may produce slightly different results from Saining Xie's official implementations,
+# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
+# and in this way it works better for gradio's RGB protocol
+
+import os
+import cv2
+import torch
+import numpy as np
+from torch.nn.parallel import DataParallel
+from einops import rearrange
+from condition.utils import annotator_ckpts_path
+import torch.nn.functional as F
+
+class DoubleConvBlock(torch.nn.Module):
+    def __init__(self, input_channel, output_channel, layer_number):
+        super().__init__()
+        self.convs = torch.nn.Sequential()
+        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+        for i in range(1, layer_number):
+            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
+
+    def __call__(self, x, down_sampling=False):
+        h = x
+        if down_sampling:
+            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
+        for conv in self.convs:
+            h = conv(h)
+            h = torch.nn.functional.relu(h)
+        return h, self.projection(h)
+
+
+class ControlNetHED_Apache2(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
+        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
+        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
+        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
+        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
+        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
+
+    def __call__(self, x):
+        h = x - self.norm
+        h, projection1 = self.block1(h)
+        h, projection2 = self.block2(h, down_sampling=True)
+        h, projection3 = self.block3(h, down_sampling=True)
+        h, projection4 = self.block4(h, down_sampling=True)
+        h, projection5 = self.block5(h, down_sampling=True)
+        return projection1, projection2, projection3, projection4, projection5
+
+
+class HEDdetector(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
+        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        self.netNetwork = ControlNetHED_Apache2().float()#.to(self.device).eval()
+        self.netNetwork.load_state_dict(torch.load(modelpath))
+
+    def __call__(self, input_image):
+        """
+        input: tensor (B,C,H,W)
+        output: tensor (B,H,W)
+        """
+        B, C, H, W = input_image.shape
+        image_hed = input_image
+
+        edges = self.netNetwork(image_hed)
+        edges = [F.interpolate(e, size=(H, W), mode='bilinear', align_corners=False).squeeze(1) for e in edges]
+        edges = torch.stack(edges, dim=1)
+        edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1)))
+        edge = (edge * 255.0).clamp(0, 255)
+
+        return edge
+
+
+def nms(x, t, s):
+    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+    y = np.zeros_like(x)
+
+    for f in [f1, f2, f3, f4]:
+        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+    z = np.zeros_like(y, dtype=np.uint8)
+    z[y > t] = 255
+    return z
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+    from tqdm import tqdm
+    import torch.nn.functional as F
+    device = torch.device('cuda')
+    apply_hed = HEDdetector().to(device).eval()
+    img = cv2.imread('condition/dragon_1024_512.jpg')
+    H,W = img.shape[:2]
+    resize_img = cv2.resize(img,(512,1024))
+    detected_map = apply_hed(torch.from_numpy(img).permute(2,0,1).unsqueeze(0).cuda())
+    resize_detected_map = apply_hed(torch.from_numpy(resize_img).permute(2,0,1).unsqueeze(0).cuda())
+    cv2.imwrite('condition/example_hed_resize.jpg', resize_detected_map[0].cpu().detach().numpy())
+    resize_detected_map = F.interpolate(resize_detected_map.unsqueeze(0).to(torch.float32), size=(H,W), mode='bilinear', align_corners=False, antialias=True)
+    print(abs(detected_map - resize_detected_map).sum())
+    print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
+    cv2.imwrite('condition/example_hed.jpg', detected_map[0].cpu().detach().numpy())
+    cv2.imwrite('condition/example_hed_resized.jpg', resize_detected_map[0,0].cpu().detach().numpy())
    	
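One thing the __main__ above does not exercise is the nms() helper. Below is a minimal sketch of the usual soft-edge-to-scribble path; it is not part of the commit, assumes a CUDA device and the ControlNetHED.pth checkpoint that HEDdetector.__init__ downloads, and the threshold 127 and sigma 3.0 are illustrative values rather than values taken from this repo.

# Sketch (not part of the commit): thin the soft HED map into a binary
# scribble-style map with the nms() helper defined above.
import cv2
import torch

detector = HEDdetector().cuda().eval()
img = cv2.imread('condition/example/t2i/multigen/house.jpg')             # BGR uint8
x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().cuda()   # (1,3,H,W)
with torch.no_grad():
    soft_edge = detector(x)[0].cpu().numpy()     # (H,W), values in 0..255
scribble = nms(soft_edge, 127, 3.0)              # 0/255 uint8 map; parameters are illustrative
cv2.imwrite('hed_scribble.jpg', scribble)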
condition/lineart.py
ADDED
@@ -0,0 +1,98 @@
+from controlnet_aux import LineartDetector
+import torch
+import cv2
+import numpy as np
+import torch.nn as nn
+
+
+norm_layer = nn.InstanceNorm2d
+class ResidualBlock(nn.Module):
+    def __init__(self, in_features):
+        super(ResidualBlock, self).__init__()
+
+        conv_block = [  nn.ReflectionPad2d(1),
+                        nn.Conv2d(in_features, in_features, 3),
+                        norm_layer(in_features),
+                        nn.ReLU(inplace=True),
+                        nn.ReflectionPad2d(1),
+                        nn.Conv2d(in_features, in_features, 3),
+                        norm_layer(in_features)
+                        ]
+
+        self.conv_block = nn.Sequential(*conv_block)
+
+    def forward(self, x):
+        return x + self.conv_block(x)
+class LineArt(nn.Module):
+    def __init__(self, input_nc=3, output_nc=1, n_residual_blocks=3, sigmoid=True):
+        super(LineArt, self).__init__()
+
+        # Initial convolution block
+        model0 = [  nn.ReflectionPad2d(3),
+                    nn.Conv2d(input_nc, 64, 7),
+                    norm_layer(64),
+                    nn.ReLU(inplace=True) ]
+        self.model0 = nn.Sequential(*model0)
+
+        # Downsampling
+        model1 = []
+        in_features = 64
+        out_features = in_features*2
+        for _ in range(2):
+            model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                        norm_layer(out_features),
+                        nn.ReLU(inplace=True) ]
+            in_features = out_features
+            out_features = in_features*2
+        self.model1 = nn.Sequential(*model1)
+
+        model2 = []
+        # Residual blocks
+        for _ in range(n_residual_blocks):
+            model2 += [ResidualBlock(in_features)]
+        self.model2 = nn.Sequential(*model2)
+
+        # Upsampling
+        model3 = []
+        out_features = in_features//2
+        for _ in range(2):
+            model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                        norm_layer(out_features),
+                        nn.ReLU(inplace=True) ]
+            in_features = out_features
+            out_features = in_features//2
+        self.model3 = nn.Sequential(*model3)
+
+        # Output layer
+        model4 = [  nn.ReflectionPad2d(3),
+                    nn.Conv2d(64, output_nc, 7)]
+        if sigmoid:
+            model4 += [nn.Sigmoid()]
+
+        self.model4 = nn.Sequential(*model4)
+
+    def forward(self, x, cond=None):
+        """
+        input: tensor (B,C,H,W)
+        output: tensor (B,1,H,W) 0~1
+        """
+
+        out = self.model0(x)
+        out = self.model1(out)
+        out = self.model2(out)
+        out = self.model3(out)
+        out = self.model4(out)
+
+        return out
+
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+    from tqdm import tqdm
+    apply_lineart = LineArt()
+    apply_lineart.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
+    img = cv2.imread('condition/car_448_768.jpg')
+    img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).repeat(8,1,1,1).float()
+    detected_map = apply_lineart(img)
+    print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
+    cv2.imwrite('condition/example_lineart.jpg', 255*detected_map[0,0].cpu().detach().numpy())
    	
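The controlnet_aux.LineartDetector import at the top of this file is not used in the __main__. A minimal sketch of that higher-level path would look like the following; it is not part of the commit, and the "lllyasviel/Annotators" repo id is an assumption about where the annotator weights are hosted.

# Sketch (not part of the commit): higher-level line-art extraction via
# controlnet_aux; the repo id below is an assumption.
from PIL import Image
from controlnet_aux import LineartDetector

detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
image = Image.open('condition/example/t2i/multigen/girl.jpg')
lineart_map = detector(image)        # PIL.Image with the extracted line art
lineart_map.save('lineart_from_aux.jpg')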
condition/midas/depth.py
ADDED
@@ -0,0 +1,223 @@
+# Midas Depth Estimation
+# From https://github.com/isl-org/MiDaS
+# MIT LICENSE
+
+import cv2
+import numpy as np
+import torch
+import os
+import sys
+current_directory = os.getcwd()
+sys.path.append(current_directory)
+from einops import rearrange
+# from .api import MiDaSInference
+from condition.utils import annotator_ckpts_path
+from condition.midas.midas.dpt_depth import DPTDepthModel
+from condition.midas.midas.midas_net import MidasNet
+from condition.midas.midas.midas_net_custom import MidasNet_small
+from condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+import os
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+ISL_PATHS = {
+    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
+    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
+    "midas_v21": "",
+    "midas_v21_small": "",
+}
+
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def load_midas_transform(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load transform only
+    if model_type == "dpt_large":  # DPT-Large
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    elif model_type == "midas_v21_small":
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    else:
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return transform
+
+
+def load_model(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load network
+    model_path = ISL_PATHS[model_type]
+    if model_type == "dpt_large":  # DPT-Large
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitl16_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        if not os.path.exists(model_path):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        model = MidasNet(model_path, non_negative=True)
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    elif model_type == "midas_v21_small":
+        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+                               non_negative=True, blocks={'expand': True})
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    else:
+        print(f"model_type '{model_type}' not implemented, use: --model_type large")
+        assert False
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+    MODEL_TYPES_TORCH_HUB = [
+        "DPT_Large",
+        "DPT_Hybrid",
+        "MiDaS_small"
+    ]
+    MODEL_TYPES_ISL = [
+        "dpt_large",
+        "dpt_hybrid",
+        "midas_v21",
+        "midas_v21_small",
+    ]
+
+    def __init__(self, model_type):
+        super().__init__()
+        assert (model_type in self.MODEL_TYPES_ISL)
+        model, _ = load_model(model_type)
+        self.model = model
+        self.model.train = disabled_train
+
+    def forward(self, x):
+        with torch.no_grad():
+            prediction = self.model(x)
+        return prediction
+
+
+class MidasDetector:
+    def __init__(self, device=torch.device('cuda:0'), model_type="dpt_hybrid"):
+        self.device = device
+        self.model = MiDaSInference(model_type=model_type).to(device)
+
+    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
+        assert input_image.ndim == 3
+        image_depth = input_image
+        with torch.no_grad():
+            image_depth = image_depth
+            image_depth = image_depth / 127.5 - 1.0
+            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+            depth = self.model(image_depth)[0]
+
+            depth_pt = depth.clone()
+            depth_pt -= torch.min(depth_pt)
+            depth_pt /= torch.max(depth_pt)
+            depth_pt = depth_pt.cpu().numpy()
+            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+            depth_np = depth.cpu().numpy()
+            x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+            y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+            z = np.ones_like(x) * a
+            x[depth_pt < bg_th] = 0
+            y[depth_pt < bg_th] = 0
+            # normal = np.stack([x, y, z], axis=2)
+            # normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+            # normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+
+            return depth_image  #, normal_image
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+    from tqdm import tqdm
+    from PIL import Image
+    import torchvision.transforms.functional as F
+    apply_depth = MidasDetector(device=torch.device('cuda:0'))
+    img = cv2.imread('/data/vjuicefs_sz_cv_v2/11171709/ControlAR_github/condition/example/t2i/multi_resolution/car_1_448_768.jpg')
+    img = cv2.resize(img,(768,448))
+    detected_map = apply_depth(torch.from_numpy(img).cuda().float())
+    print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
+    plt.imshow(detected_map, cmap='gray')
+    plt.show()
+    cv2.imwrite('condition/example_depth.jpg', detected_map)
+    # cv2.imwrite('condition/example_normal.jpg', normal_map)
    	
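For reference, MidasDetector.__call__ expects a single (H, W, C) tensor in the 0..255 range on the model's device and returns a uint8 depth map. A minimal sketch against one of the example images shipped in this commit (not part of the commit itself; the image is resized to multiples of 32, as in the __main__ above):

# Sketch (not part of the commit): depth extraction with MidasDetector on a
# repo-relative example image instead of the absolute path used in __main__.
import cv2
import torch

apply_depth = MidasDetector(device=torch.device('cuda:0'), model_type="dpt_hybrid")
img = cv2.imread('condition/example/t2i/multi_resolution/car.jpg')
img = cv2.resize(img, (768, 448))                               # W, H multiples of 32
depth_map = apply_depth(torch.from_numpy(img).cuda().float())   # uint8 map, same H x W as the input
cv2.imwrite('car_depth.jpg', depth_map)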
condition/midas/midas/__init__.py
ADDED
File without changes
        condition/midas/midas/base_model.py
    ADDED
    
    | 
         @@ -0,0 +1,16 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            class BaseModel(torch.nn.Module):
         
     | 
| 5 | 
         
            +
                def load(self, path):
         
     | 
| 6 | 
         
            +
                    """Load model from file.
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
                    Args:
         
     | 
| 9 | 
         
            +
                        path (str): file path
         
     | 
| 10 | 
         
            +
                    """
         
     | 
| 11 | 
         
            +
                    parameters = torch.load(path, map_location=torch.device('cpu'))
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
                    if "optimizer" in parameters:
         
     | 
| 14 | 
         
            +
                        parameters = parameters["model"]
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                    self.load_state_dict(parameters)
         
     | 
    	
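BaseModel.load() restores either a bare state dict or a full training checkpoint whose weights sit under the 'model' key. A hypothetical subclass, just to illustrate the contract (the class and path below are placeholders, not part of the commit):

# Hypothetical example (not part of the commit): any torch.nn.Module that
# inherits BaseModel gets this checkpoint-loading behaviour for free.
import torch

class TinyDepthHead(BaseModel):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)

net = TinyDepthHead()
# net.load('condition/ckpts/tiny_depth.pt')   # placeholder path; accepts either a plain
#                                             # state_dict or {'model': ..., 'optimizer': ...}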
        condition/midas/midas/blocks.py
    ADDED
    
    | 
         @@ -0,0 +1,341 @@ 
     | 
import torch
import torch.nn as nn

from .vit import (
    _make_pretrained_vitb_rn50_384,
    _make_pretrained_vitl16_384,
    _make_pretrained_vitb16_384,
    forward_vit,
)

def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
    if backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT hybrid (ResNet-50 + ViT-B/16) backbone
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # ResNeXt-101 32x8d WSL
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand==True:
        out_shape1 = out_shape
        out_shape2 = out_shape*2
        out_shape3 = out_shape*4
        out_shape4 = out_shape*8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
    )
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
    )

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
        )

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
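To make the decoder plumbing concrete, here is a small shape-check sketch (not part of the commit). It assumes the repo root is on PYTHONPATH and that timm is installed, because importing blocks.py also imports .vit; the channel widths [32, 48, 136, 384] are the ones the efficientnet_lite3 branch of _make_encoder passes to _make_scratch.

import torch
import torch.nn as nn

from condition.midas.midas.blocks import _make_scratch, FeatureFusionBlock_custom

# Reassemble convs for the efficientnet_lite3 feature pyramid; expand=True doubles width per level.
scratch = _make_scratch([32, 48, 136, 384], 64, expand=True)

# Fake encoder features at strides 4/8/16/32 for a 256x256 input.
l1, l2, l3, l4 = (torch.randn(1, c, s, s) for c, s in [(32, 64), (48, 32), (136, 16), (384, 8)])

l1_rn = scratch.layer1_rn(l1)  # 64 channels
l2_rn = scratch.layer2_rn(l2)  # 128 channels
l3_rn = scratch.layer3_rn(l3)  # 256 channels
l4_rn = scratch.layer4_rn(l4)  # 512 channels

fuse4 = FeatureFusionBlock_custom(512, nn.ReLU(False), expand=True, align_corners=True)
fuse3 = FeatureFusionBlock_custom(256, nn.ReLU(False), expand=True, align_corners=True)

path_4 = fuse4(l4_rn)          # upsampled 2x, channels halved to 256
path_3 = fuse3(path_4, l3_rn)  # fused with the skip connection, halved to 128
print(path_4.shape, path_3.shape)  # (1, 256, 16, 16) (1, 128, 32, 32)

Each fusion block thus undoes one level of the expand=True widening while doubling the spatial resolution, which is exactly how MidasNet_small chains them below.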
    	
condition/midas/midas/dpt_depth.py    ADDED    @@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to True if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last == True:
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
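A quick smoke test of the DPT wiring (a sketch, not part of the commit). It assumes timm is installed; no weights are fetched because DPT builds the encoder with use_pretrained=False, and path=None skips checkpoint loading.

import torch

from condition.midas.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone="vitl16_384", non_negative=True).eval()

# The *_384 backbones expect 384x384 inputs.
dummy = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    depth = model(dummy)

print(depth.shape)  # torch.Size([1, 384, 384]); the channel dim is squeezed in forward()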
    	
condition/midas/midas/midas_net.py    ADDED    @@ -0,0 +1,76 @@
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): Clamp the predicted depth to be non-negative. Defaults to True.
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
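Usage sketch for the ResNeXt-based model (not part of the commit). Even with path=None the constructor calls torch.hub.load for the ResNeXt-101 WSL backbone, so the first run needs network access; the MiDaS depth weights themselves are only restored when a checkpoint path is given.

import torch

from condition.midas.midas.midas_net import MidasNet

net = MidasNet(path=None, features=256, non_negative=True).eval()

x = torch.randn(1, 3, 384, 384)  # MiDaS v2.1 large is trained at 384x384
with torch.no_grad():
    depth = net(x)

print(depth.shape)  # torch.Size([1, 384, 384])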
    	
condition/midas/midas/midas_net_custom.py    ADDED    @@ -0,0 +1,128 @@
"""MidasNet_small: network for monocular depth estimation trained by mixing several datasets.

This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder


class MidasNet_small(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True,
                 exportable=True, channels_last=False, align_corners=True, blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for the encoder. Defaults to efficientnet_lite3.
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        # Only load ImageNet-pretrained backbone weights when no checkpoint path is given.
        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand'] == True:
            # Expanded decoder: channel width doubles at each deeper refinement stage.
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(
            self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable
        )

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last == True:
            print("self.channels_last = ", self.channels_last)
            # Note: contiguous() is not in-place; the result must be reassigned.
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)


def fuse_model(m):
    # Fuse Conv2d+BatchNorm2d(+ReLU) triples in-place for quantization/inference.
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #    print("FUSED ", previous_name, name)
        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name
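For orientation, here is a minimal usage sketch of the model defined above. It is not part of the commit; it assumes the condition.midas.midas package is importable and that the timm weights for the efficientnet_lite3 backbone can be fetched.

# Minimal sketch (illustrative, not part of this commit).
import torch
from condition.midas.midas.midas_net_custom import MidasNet_small, fuse_model

model = MidasNet_small(path=None, features=64, backbone="efficientnet_lite3")
model.eval()

x = torch.randn(1, 3, 256, 256)       # NCHW image batch at the network resolution
with torch.no_grad():
    depth = model(x)                  # forward() squeezes the channel dim -> (N, H, W)
print(depth.shape)

# fuse_model(model)                   # optional: fuse Conv+BN(+ReLU) before quantization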
    	
condition/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be a multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample
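For orientation, here is a minimal sketch of chaining these dict-based transforms. It is not part of the commit; the 384x384 target size and the normalization statistics are illustrative values, and the import path assumes condition.midas.midas is importable as a package.

# Minimal sketch (illustrative, not part of this commit).
import numpy as np
from condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

transforms = [
    Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="lower_bound"),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
]

sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}  # HWC image in [0, 1]
for t in transforms:
    sample = t(sample)

print(sample["image"].shape)  # CHW float32; both sides are multiples of 32 and >= 384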
    	
condition/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F


class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index :]


class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)

        return self.project(features)


class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x


def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4


def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb


def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x


activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


def get_readout_oper(vit_features, features, use_readout, start_index=1):
                if use_readout == "ignore":
         
     | 
| 168 | 
         
            +
                    readout_oper = [Slice(start_index)] * len(features)
         
     | 
| 169 | 
         
            +
                elif use_readout == "add":
         
     | 
| 170 | 
         
            +
                    readout_oper = [AddReadout(start_index)] * len(features)
         
     | 
| 171 | 
         
            +
                elif use_readout == "project":
         
     | 
| 172 | 
         
            +
                    readout_oper = [
         
     | 
| 173 | 
         
            +
                        ProjectReadout(vit_features, start_index) for out_feat in features
         
     | 
| 174 | 
         
            +
                    ]
         
     | 
| 175 | 
         
            +
                else:
         
     | 
| 176 | 
         
            +
                    assert (
         
     | 
| 177 | 
         
            +
                        False
         
     | 
| 178 | 
         
            +
                    ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
         
     | 
| 179 | 
         
            +
             
     | 
| 180 | 
         
            +
                return readout_oper
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
            def _make_vit_b16_backbone(
         
     | 
| 184 | 
         
            +
                model,
         
     | 
| 185 | 
         
            +
                features=[96, 192, 384, 768],
         
     | 
| 186 | 
         
            +
                size=[384, 384],
         
     | 
| 187 | 
         
            +
                hooks=[2, 5, 8, 11],
         
     | 
| 188 | 
         
            +
                vit_features=768,
         
     | 
| 189 | 
         
            +
                use_readout="ignore",
         
     | 
| 190 | 
         
            +
                start_index=1,
         
     | 
| 191 | 
         
            +
            ):
         
     | 
| 192 | 
         
            +
                pretrained = nn.Module()
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                pretrained.model = model
         
     | 
| 195 | 
         
            +
                pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
         
     | 
| 196 | 
         
            +
                pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
         
     | 
| 197 | 
         
            +
                pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
         
     | 
| 198 | 
         
            +
                pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                pretrained.activations = activations
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                # 32, 48, 136, 384
         
     | 
| 205 | 
         
            +
                pretrained.act_postprocess1 = nn.Sequential(
         
     | 
| 206 | 
         
            +
                    readout_oper[0],
         
     | 
| 207 | 
         
            +
                    Transpose(1, 2),
         
     | 
| 208 | 
         
            +
                    nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 209 | 
         
            +
                    nn.Conv2d(
         
     | 
| 210 | 
         
            +
                        in_channels=vit_features,
         
     | 
| 211 | 
         
            +
                        out_channels=features[0],
         
     | 
| 212 | 
         
            +
                        kernel_size=1,
         
     | 
| 213 | 
         
            +
                        stride=1,
         
     | 
| 214 | 
         
            +
                        padding=0,
         
     | 
| 215 | 
         
            +
                    ),
         
     | 
| 216 | 
         
            +
                    nn.ConvTranspose2d(
         
     | 
| 217 | 
         
            +
                        in_channels=features[0],
         
     | 
| 218 | 
         
            +
                        out_channels=features[0],
         
     | 
| 219 | 
         
            +
                        kernel_size=4,
         
     | 
| 220 | 
         
            +
                        stride=4,
         
     | 
| 221 | 
         
            +
                        padding=0,
         
     | 
| 222 | 
         
            +
                        bias=True,
         
     | 
| 223 | 
         
            +
                        dilation=1,
         
     | 
| 224 | 
         
            +
                        groups=1,
         
     | 
| 225 | 
         
            +
                    ),
         
     | 
| 226 | 
         
            +
                )
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
                pretrained.act_postprocess2 = nn.Sequential(
         
     | 
| 229 | 
         
            +
                    readout_oper[1],
         
     | 
| 230 | 
         
            +
                    Transpose(1, 2),
         
     | 
| 231 | 
         
            +
                    nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 232 | 
         
            +
                    nn.Conv2d(
         
     | 
| 233 | 
         
            +
                        in_channels=vit_features,
         
     | 
| 234 | 
         
            +
                        out_channels=features[1],
         
     | 
| 235 | 
         
            +
                        kernel_size=1,
         
     | 
| 236 | 
         
            +
                        stride=1,
         
     | 
| 237 | 
         
            +
                        padding=0,
         
     | 
| 238 | 
         
            +
                    ),
         
     | 
| 239 | 
         
            +
                    nn.ConvTranspose2d(
         
     | 
| 240 | 
         
            +
                        in_channels=features[1],
         
     | 
| 241 | 
         
            +
                        out_channels=features[1],
         
     | 
| 242 | 
         
            +
                        kernel_size=2,
         
     | 
| 243 | 
         
            +
                        stride=2,
         
     | 
| 244 | 
         
            +
                        padding=0,
         
     | 
| 245 | 
         
            +
                        bias=True,
         
     | 
| 246 | 
         
            +
                        dilation=1,
         
     | 
| 247 | 
         
            +
                        groups=1,
         
     | 
| 248 | 
         
            +
                    ),
         
     | 
| 249 | 
         
            +
                )
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                pretrained.act_postprocess3 = nn.Sequential(
         
     | 
| 252 | 
         
            +
                    readout_oper[2],
         
     | 
| 253 | 
         
            +
                    Transpose(1, 2),
         
     | 
| 254 | 
         
            +
                    nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 255 | 
         
            +
                    nn.Conv2d(
         
     | 
| 256 | 
         
            +
                        in_channels=vit_features,
         
     | 
| 257 | 
         
            +
                        out_channels=features[2],
         
     | 
| 258 | 
         
            +
                        kernel_size=1,
         
     | 
| 259 | 
         
            +
                        stride=1,
         
     | 
| 260 | 
         
            +
                        padding=0,
         
     | 
| 261 | 
         
            +
                    ),
         
     | 
| 262 | 
         
            +
                )
         
     | 
| 263 | 
         
            +
             
     | 
| 264 | 
         
            +
                pretrained.act_postprocess4 = nn.Sequential(
         
     | 
| 265 | 
         
            +
                    readout_oper[3],
         
     | 
| 266 | 
         
            +
                    Transpose(1, 2),
         
     | 
| 267 | 
         
            +
                    nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 268 | 
         
            +
                    nn.Conv2d(
         
     | 
| 269 | 
         
            +
                        in_channels=vit_features,
         
     | 
| 270 | 
         
            +
                        out_channels=features[3],
         
     | 
| 271 | 
         
            +
                        kernel_size=1,
         
     | 
| 272 | 
         
            +
                        stride=1,
         
     | 
| 273 | 
         
            +
                        padding=0,
         
     | 
| 274 | 
         
            +
                    ),
         
     | 
| 275 | 
         
            +
                    nn.Conv2d(
         
     | 
| 276 | 
         
            +
                        in_channels=features[3],
         
     | 
| 277 | 
         
            +
                        out_channels=features[3],
         
     | 
| 278 | 
         
            +
                        kernel_size=3,
         
     | 
| 279 | 
         
            +
                        stride=2,
         
     | 
| 280 | 
         
            +
                        padding=1,
         
     | 
| 281 | 
         
            +
                    ),
         
     | 
| 282 | 
         
            +
                )
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
                pretrained.model.start_index = start_index
         
     | 
| 285 | 
         
            +
                pretrained.model.patch_size = [16, 16]
         
     | 
| 286 | 
         
            +
             
     | 
| 287 | 
         
            +
                # We inject this function into the VisionTransformer instances so that
         
     | 
| 288 | 
         
            +
                # we can use it with interpolated position embeddings without modifying the library source.
         
     | 
| 289 | 
         
            +
                pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
         
     | 
| 290 | 
         
            +
                pretrained.model._resize_pos_embed = types.MethodType(
         
     | 
| 291 | 
         
            +
                    _resize_pos_embed, pretrained.model
         
     | 
| 292 | 
         
            +
                )
         
     | 
| 293 | 
         
            +
             
     | 
| 294 | 
         
            +
                return pretrained
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
             
     | 
| 297 | 
         
            +
            def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
         
     | 
| 298 | 
         
            +
                model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
         
     | 
| 299 | 
         
            +
             
     | 
| 300 | 
         
            +
                hooks = [5, 11, 17, 23] if hooks == None else hooks
         
     | 
| 301 | 
         
            +
                return _make_vit_b16_backbone(
         
     | 
| 302 | 
         
            +
                    model,
         
     | 
| 303 | 
         
            +
                    features=[256, 512, 1024, 1024],
         
     | 
| 304 | 
         
            +
                    hooks=hooks,
         
     | 
| 305 | 
         
            +
                    vit_features=1024,
         
     | 
| 306 | 
         
            +
                    use_readout=use_readout,
         
     | 
| 307 | 
         
            +
                )
         
     | 
| 308 | 
         
            +
             
     | 
| 309 | 
         
            +
             
     | 
| 310 | 
         
            +
            def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
         
     | 
| 311 | 
         
            +
                model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
                hooks = [2, 5, 8, 11] if hooks == None else hooks
         
     | 
| 314 | 
         
            +
                return _make_vit_b16_backbone(
         
     | 
| 315 | 
         
            +
                    model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
         
     | 
| 316 | 
         
            +
                )
         
     | 
| 317 | 
         
            +
             
     | 
| 318 | 
         
            +
             
     | 
| 319 | 
         
            +
            def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
         
     | 
| 320 | 
         
            +
                model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
         
     | 
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
                hooks = [2, 5, 8, 11] if hooks == None else hooks
         
     | 
| 323 | 
         
            +
                return _make_vit_b16_backbone(
         
     | 
| 324 | 
         
            +
                    model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
         
     | 
| 325 | 
         
            +
                )
         
     | 
| 326 | 
         
            +
             
     | 
| 327 | 
         
            +
             
     | 
| 328 | 
         
            +
            def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
         
     | 
| 329 | 
         
            +
                model = timm.create_model(
         
     | 
| 330 | 
         
            +
                    "vit_deit_base_distilled_patch16_384", pretrained=pretrained
         
     | 
| 331 | 
         
            +
                )
         
     | 
| 332 | 
         
            +
             
     | 
| 333 | 
         
            +
                hooks = [2, 5, 8, 11] if hooks == None else hooks
         
     | 
| 334 | 
         
            +
                return _make_vit_b16_backbone(
         
     | 
| 335 | 
         
            +
                    model,
         
     | 
| 336 | 
         
            +
                    features=[96, 192, 384, 768],
         
     | 
| 337 | 
         
            +
                    hooks=hooks,
         
     | 
| 338 | 
         
            +
                    use_readout=use_readout,
         
     | 
| 339 | 
         
            +
                    start_index=2,
         
     | 
| 340 | 
         
            +
                )
         
     | 
| 341 | 
         
            +
             
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
            def _make_vit_b_rn50_backbone(
         
     | 
| 344 | 
         
            +
                model,
         
     | 
| 345 | 
         
            +
                features=[256, 512, 768, 768],
         
     | 
| 346 | 
         
            +
                size=[384, 384],
         
     | 
| 347 | 
         
            +
                hooks=[0, 1, 8, 11],
         
     | 
| 348 | 
         
            +
                vit_features=768,
         
     | 
| 349 | 
         
            +
                use_vit_only=False,
         
     | 
| 350 | 
         
            +
                use_readout="ignore",
         
     | 
| 351 | 
         
            +
                start_index=1,
         
     | 
| 352 | 
         
            +
            ):
         
     | 
| 353 | 
         
            +
                pretrained = nn.Module()
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
            +
                pretrained.model = model
         
     | 
| 356 | 
         
            +
             
     | 
| 357 | 
         
            +
                if use_vit_only == True:
         
     | 
| 358 | 
         
            +
                    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
         
     | 
| 359 | 
         
            +
                    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
         
     | 
| 360 | 
         
            +
                else:
         
     | 
| 361 | 
         
            +
                    pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
         
     | 
| 362 | 
         
            +
                        get_activation("1")
         
     | 
| 363 | 
         
            +
                    )
         
     | 
| 364 | 
         
            +
                    pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
         
     | 
| 365 | 
         
            +
                        get_activation("2")
         
     | 
| 366 | 
         
            +
                    )
         
     | 
| 367 | 
         
            +
             
     | 
| 368 | 
         
            +
                pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
         
     | 
| 369 | 
         
            +
                pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
         
     | 
| 370 | 
         
            +
             
     | 
| 371 | 
         
            +
                pretrained.activations = activations
         
     | 
| 372 | 
         
            +
             
     | 
| 373 | 
         
            +
                readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
         
     | 
| 374 | 
         
            +
             
     | 
| 375 | 
         
            +
                if use_vit_only == True:
         
     | 
| 376 | 
         
            +
                    pretrained.act_postprocess1 = nn.Sequential(
         
     | 
| 377 | 
         
            +
                        readout_oper[0],
         
     | 
| 378 | 
         
            +
                        Transpose(1, 2),
         
     | 
| 379 | 
         
            +
                        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 380 | 
         
            +
                        nn.Conv2d(
         
     | 
| 381 | 
         
            +
                            in_channels=vit_features,
         
     | 
| 382 | 
         
            +
                            out_channels=features[0],
         
     | 
| 383 | 
         
            +
                            kernel_size=1,
         
     | 
| 384 | 
         
            +
                            stride=1,
         
     | 
| 385 | 
         
            +
                            padding=0,
         
     | 
| 386 | 
         
            +
                        ),
         
     | 
| 387 | 
         
            +
                        nn.ConvTranspose2d(
         
     | 
| 388 | 
         
            +
                            in_channels=features[0],
         
     | 
| 389 | 
         
            +
                            out_channels=features[0],
         
     | 
| 390 | 
         
            +
                            kernel_size=4,
         
     | 
| 391 | 
         
            +
                            stride=4,
         
     | 
| 392 | 
         
            +
                            padding=0,
         
     | 
| 393 | 
         
            +
                            bias=True,
         
     | 
| 394 | 
         
            +
                            dilation=1,
         
     | 
| 395 | 
         
            +
                            groups=1,
         
     | 
| 396 | 
         
            +
                        ),
         
     | 
| 397 | 
         
            +
                    )
         
     | 
| 398 | 
         
            +
             
     | 
| 399 | 
         
            +
                    pretrained.act_postprocess2 = nn.Sequential(
         
     | 
| 400 | 
         
            +
                        readout_oper[1],
         
     | 
| 401 | 
         
            +
                        Transpose(1, 2),
         
     | 
| 402 | 
         
            +
                        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 403 | 
         
            +
                        nn.Conv2d(
         
     | 
| 404 | 
         
            +
                            in_channels=vit_features,
         
     | 
| 405 | 
         
            +
                            out_channels=features[1],
         
     | 
| 406 | 
         
            +
                            kernel_size=1,
         
     | 
| 407 | 
         
            +
                            stride=1,
         
     | 
| 408 | 
         
            +
                            padding=0,
         
     | 
| 409 | 
         
            +
                        ),
         
     | 
| 410 | 
         
            +
                        nn.ConvTranspose2d(
         
     | 
| 411 | 
         
            +
                            in_channels=features[1],
         
     | 
| 412 | 
         
            +
                            out_channels=features[1],
         
     | 
| 413 | 
         
            +
                            kernel_size=2,
         
     | 
| 414 | 
         
            +
                            stride=2,
         
     | 
| 415 | 
         
            +
                            padding=0,
         
     | 
| 416 | 
         
            +
                            bias=True,
         
     | 
| 417 | 
         
            +
                            dilation=1,
         
     | 
| 418 | 
         
            +
                            groups=1,
         
     | 
| 419 | 
         
            +
                        ),
         
     | 
| 420 | 
         
            +
                    )
         
     | 
| 421 | 
         
            +
                else:
         
     | 
| 422 | 
         
            +
                    pretrained.act_postprocess1 = nn.Sequential(
         
     | 
| 423 | 
         
            +
                        nn.Identity(), nn.Identity(), nn.Identity()
         
     | 
| 424 | 
         
            +
                    )
         
     | 
| 425 | 
         
            +
                    pretrained.act_postprocess2 = nn.Sequential(
         
     | 
| 426 | 
         
            +
                        nn.Identity(), nn.Identity(), nn.Identity()
         
     | 
| 427 | 
         
            +
                    )
         
     | 
| 428 | 
         
            +
             
     | 
| 429 | 
         
            +
                pretrained.act_postprocess3 = nn.Sequential(
         
     | 
| 430 | 
         
            +
                    readout_oper[2],
         
     | 
| 431 | 
         
            +
                    Transpose(1, 2),
         
     | 
| 432 | 
         
            +
                    nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 433 | 
         
            +
                    nn.Conv2d(
         
     | 
| 434 | 
         
            +
                        in_channels=vit_features,
         
     | 
| 435 | 
         
            +
                        out_channels=features[2],
         
     | 
| 436 | 
         
            +
                        kernel_size=1,
         
     | 
| 437 | 
         
            +
                        stride=1,
         
     | 
| 438 | 
         
            +
                        padding=0,
         
     | 
| 439 | 
         
            +
                    ),
         
     | 
| 440 | 
         
            +
                )
         
     | 
| 441 | 
         
            +
             
     | 
| 442 | 
         
            +
                pretrained.act_postprocess4 = nn.Sequential(
         
     | 
| 443 | 
         
            +
                    readout_oper[3],
         
     | 
| 444 | 
         
            +
                    Transpose(1, 2),
         
     | 
| 445 | 
         
            +
                    nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
         
     | 
| 446 | 
         
            +
                    nn.Conv2d(
         
     | 
| 447 | 
         
            +
                        in_channels=vit_features,
         
     | 
| 448 | 
         
            +
                        out_channels=features[3],
         
     | 
| 449 | 
         
            +
                        kernel_size=1,
         
     | 
| 450 | 
         
            +
                        stride=1,
         
     | 
| 451 | 
         
            +
                        padding=0,
         
     | 
| 452 | 
         
            +
                    ),
         
     | 
| 453 | 
         
            +
                    nn.Conv2d(
         
     | 
| 454 | 
         
            +
                        in_channels=features[3],
         
     | 
| 455 | 
         
            +
                        out_channels=features[3],
         
     | 
| 456 | 
         
            +
                        kernel_size=3,
         
     | 
| 457 | 
         
            +
                        stride=2,
         
     | 
| 458 | 
         
            +
                        padding=1,
         
     | 
| 459 | 
         
            +
                    ),
         
     | 
| 460 | 
         
            +
                )
         
     | 
| 461 | 
         
            +
             
     | 
| 462 | 
         
            +
                pretrained.model.start_index = start_index
         
     | 
| 463 | 
         
            +
                pretrained.model.patch_size = [16, 16]
         
     | 
| 464 | 
         
            +
             
     | 
| 465 | 
         
            +
                # We inject this function into the VisionTransformer instances so that
         
     | 
| 466 | 
         
            +
                # we can use it with interpolated position embeddings without modifying the library source.
         
     | 
| 467 | 
         
            +
                pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
         
     | 
| 468 | 
         
            +
             
     | 
| 469 | 
         
            +
                # We inject this function into the VisionTransformer instances so that
         
     | 
| 470 | 
         
            +
                # we can use it with interpolated position embeddings without modifying the library source.
         
     | 
| 471 | 
         
            +
                pretrained.model._resize_pos_embed = types.MethodType(
         
     | 
| 472 | 
         
            +
                    _resize_pos_embed, pretrained.model
         
     | 
| 473 | 
         
            +
                )
         
     | 
| 474 | 
         
            +
             
     | 
| 475 | 
         
            +
                return pretrained
         
     | 
| 476 | 
         
            +
             
     | 
| 477 | 
         
            +
             
     | 
| 478 | 
         
            +
            def _make_pretrained_vitb_rn50_384(
         
     | 
| 479 | 
         
            +
                pretrained, use_readout="ignore", hooks=None, use_vit_only=False
         
     | 
| 480 | 
         
            +
            ):
         
     | 
| 481 | 
         
            +
                model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
         
     | 
| 482 | 
         
            +
             
     | 
| 483 | 
         
            +
                hooks = [0, 1, 8, 11] if hooks == None else hooks
         
     | 
| 484 | 
         
            +
                return _make_vit_b_rn50_backbone(
         
     | 
| 485 | 
         
            +
                    model,
         
     | 
| 486 | 
         
            +
                    features=[256, 512, 768, 768],
         
     | 
| 487 | 
         
            +
                    size=[384, 384],
         
     | 
| 488 | 
         
            +
                    hooks=hooks,
         
     | 
| 489 | 
         
            +
                    use_vit_only=use_vit_only,
         
     | 
| 490 | 
         
            +
                    use_readout=use_readout,
         
     | 
| 491 | 
         
            +
                )
         
     | 
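These helpers are normally consumed by the DPT model in `condition/midas/midas/dpt_depth.py`, but a minimal smoke test of the constructors above might look like the following sketch. The 384×384 dummy input and `pretrained=False` are illustrative assumptions, not part of the commit.

```python
# Sketch: build a ViT-B/16 backbone and run one dummy forward pass (assumptions:
# timm is installed and the repository root is on PYTHONPATH).
import torch
from condition.midas.midas.vit import _make_pretrained_vitb16_384

pretrained = _make_pretrained_vitb16_384(pretrained=False, use_readout="project")

x = torch.randn(1, 3, 384, 384)  # H and W must be multiples of the 16x16 patch size
with torch.no_grad():
    tokens = pretrained.model.forward_flex(x)  # (1, 1 + 24*24, 768) token sequence

# The forward hooks registered in _make_vit_b16_backbone populate
# pretrained.activations with the intermediate block outputs keyed "1".."4".
print(tokens.shape, sorted(pretrained.activations.keys()))
```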
    	
        condition/utils.py
    ADDED
    
@@ -0,0 +1,38 @@
import numpy as np
import cv2
import os


annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
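Typical use of these two helpers, as a sketch (the image path is a placeholder and the repository root is assumed to be on PYTHONPATH):

```python
# Sketch: normalize an arbitrary image to 3-channel uint8 and snap it to a
# resolution whose sides are multiples of 64, as the conditioning code expects.
import cv2
from condition.utils import HWC3, resize_image

img = cv2.imread("example.jpg")          # placeholder path; BGR uint8
img = HWC3(img)                          # force 3 channels (handles gray / RGBA)
img = resize_image(img, resolution=512)  # short side ~512, both sides multiples of 64
print(img.shape)
```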
    	
        language/README.md
    ADDED
    
@@ -0,0 +1,14 @@
## Language models for text-conditional image generation

### Requirements
```
pip install ftfy
pip install transformers
pip install accelerate
pip install sentencepiece
pip install pandas
pip install bs4
```

### Language Models
Download the flan-t5-xl model from [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) and put it into the folder `./pretrained_models/t5-ckpt/`.
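If you prefer to script the download, a sketch using `huggingface_hub` is below; the `flan-t5-xl` subfolder name is an assumption, so check the layout expected by `T5Embedder` in `language/t5.py`.

```
# Sketch: fetch google/flan-t5-xl into the checkpoint folder used by T5Embedder.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="google/flan-t5-xl",
    local_dir="./pretrained_models/t5-ckpt/flan-t5-xl",
)
```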
    	
        language/extract_t5_feature.py
    ADDED
    
@@ -0,0 +1,129 @@
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import numpy as np
import argparse
import os
import json

from utils.distributed import init_distributed_mode
from language.t5 import T5Embedder

CAPTION_KEY = {
    'blip': 0,
    'llava': 1,
    'llava_first': 2,
}
#################################################################################
#                                Helper Functions                               #
#################################################################################
class CustomDataset(Dataset):
    def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False):
        img_path_list = []
        for lst_name in sorted(os.listdir(lst_dir))[start: end+1]:
            if not lst_name.endswith('.jsonl'):
                continue
            file_path = os.path.join(lst_dir, lst_name)
            with open(file_path, 'r') as file:
                for line_idx, line in enumerate(file):
                    data = json.loads(line)
                    caption = data['text'][CAPTION_KEY[caption_key]]
                    code_dir = file_path.split('/')[-1].split('.')[0]
                    if trunc_caption:
                        caption = caption.split('.')[0]
                    img_path_list.append((caption, code_dir, line_idx))
        self.img_path_list = img_path_list

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, index):
        caption, code_dir, code_name = self.img_path_list[index]
        return caption, code_dir, code_name


#################################################################################
#                                Extraction Loop                                #
#################################################################################
def main(args):
    """
    Extracts T5 text embeddings for captions and saves them as .npy files.
    """
    assert torch.cuda.is_available(), "Feature extraction currently requires at least one GPU."

    # Setup DDP:
    init_distributed_mode(args)
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Setup a feature folder:
    if rank == 0:
        os.makedirs(args.t5_path, exist_ok=True)

    # Setup data:
    print("Preparing dataset...")
    dataset = CustomDataset(args.data_path, args.data_start, args.data_end, args.caption_key, args.trunc_caption)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=1,  # important!
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )
    print(f"Dataset contains {len(dataset):,} captions")

    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    assert os.path.exists(args.t5_model_path)
    t5_xxl = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_model_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision
    )

    for caption, code_dir, code_name in loader:
        caption_embs, emb_masks = t5_xxl.get_text_embeddings(caption)
        valid_caption_embs = caption_embs[:, :emb_masks.sum()]
        x = valid_caption_embs.to(torch.float32).detach().cpu().numpy()
        os.makedirs(os.path.join(args.t5_path, code_dir[0]), exist_ok=True)
        np.save(os.path.join(args.t5_path, code_dir[0], '{}.npy'.format(code_name.item())), x)
        print(code_name.item())

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--t5-path", type=str, required=True)
    parser.add_argument("--data-start", type=int, required=True)
    parser.add_argument("--data-end", type=int, required=True)
    parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys()))
    parser.add_argument("--trunc-caption", action='store_true', default=False)
    parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=24)
    args = parser.parse_args()
    main(args)
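For orientation, the script reads `.jsonl` files from `--data-path` where each line is a JSON object whose `text` list is indexed by `CAPTION_KEY`. The sketch below illustrates that format; the file name is a placeholder inferred from `CustomDataset`, not a convention fixed by the commit.

```python
# Sketch of one input record (placeholder paths; format inferred from CustomDataset):
# each line of a .jsonl file carries captions in a 'text' list indexed by
# CAPTION_KEY (0 = blip, 1 = llava, 2 = llava_first).
import json

record = {"text": ["a red bird perched on a branch",   # blip caption
                   "a longer llava caption ...",       # llava caption
                   "a llava caption, first variant"]}  # llava_first caption
with open("captions/000000.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")

# For each line the script saves {t5_path}/<jsonl stem>/<line_idx>.npy holding the
# valid-token T5 embeddings as float32 with shape (1, seq_len, hidden_dim).
```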
    	
        language/t5.py
    ADDED
    
@@ -0,0 +1,201 @@
# Modified from:
#   PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/t5.py
import os
import re
import html
import urllib.parse as ul

import ftfy
import torch
from bs4 import BeautifulSoup
from transformers import T5EncoderModel, AutoTokenizer
from huggingface_hub import hf_hub_download


class T5Embedder:
    available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']
    bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}')  # noqa

    def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
                 t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        if t5_model_kwargs is None:
            t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
            t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token
        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
        self.dir_or_name = dir_or_name
        tokenizer_path, path = dir_or_name, dir_or_name
        if local_cache:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            tokenizer_path, path = cache_dir, cache_dir
        elif dir_or_name in self.available_models:
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
                'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
            ]:
                hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path, path = cache_dir, cache_dir
        else:
            cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
            ]:
                hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path = cache_dir

        print(tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        texts = [self.text_preprocessing(text) for text in texts]

        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
        text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']

        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=text_tokens_and_mask['input_ids'].to(self.device),
                attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
            )['last_hidden_state'].detach()
        return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)

    def text_preprocessing(self, text):
        if self.use_text_preprocessing:
            # The exact text cleaning as was in the training stage:
            text = self.clean_caption(text)
            text = self.clean_caption(text)
            return text
        else:
            return text.lower().strip()

    @staticmethod
    def basic_clean(text):
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        return text.strip()

    def clean_caption(self, caption):
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub('<person>', 'person', caption)
        # urls:
        caption = re.sub(
            r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        caption = re.sub(
            r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features='html.parser').text

        # @<nickname>
        caption = re.sub(r'@[\w\d]+\b', '', caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
        caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
        caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
        caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
        caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
        caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
        caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
            '-', caption)

        # normalize quotation marks to a single standard
        caption = re.sub(r'[`´«»“”¨]', '"', caption)
        caption = re.sub(r'[‘’]', "'", caption)

        # &quot;
        caption = re.sub(r'&quot;?', '', caption)
        # &amp
        caption = re.sub(r'&amp', '', caption)

        # ip addresses:
        caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

        # article ids:
        caption = re.sub(r'\d:\d\d\s+$', '', caption)

        # \n
        caption = re.sub(r'\\n', ' ', caption)

        # "#123"
        caption = re.sub(r'#\d{1,3}\b', '', caption)
        # "#12345.."
        caption = re.sub(r'#\d{5,}\b', '', caption)
        # "123456.."
        caption = re.sub(r'\b\d{6,}\b', '', caption)
        # filenames:
        caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

        #
        caption = re.sub(r'[\"\']{2,}', r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r'[\.]{2,}', r' ', caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r'(?:\-|\_)')
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, ' ', caption)

        caption = self.basic_clean(caption)

        caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
        caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
        caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231

        caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
        caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
        caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
        caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
        caption = re.sub(r'\bpage\s+\d+\b', '', caption)

        caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...

        caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)

        caption = re.sub(r'\b\s+\:\s+', r': ', caption)
        caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
        caption = re.sub(r'\s+', ' ', caption)

        caption.strip()

        caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
        caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
        caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
        caption = re.sub(r'^\.\S+$', '', caption)

        return caption.strip()
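For orientation, a minimal usage sketch of the `T5Embedder` added above. It mirrors how `model.py` below constructs the embedder; the example prompt is illustrative only, and with `local_cache=True` the flan-t5-xl tokenizer and weights are expected under `~/.cache/IF_/flan-t5-xl`.

```
# Minimal sketch: embed a batch of prompts with the T5Embedder from language/t5.py.
import torch
from language.t5 import T5Embedder

device = "cuda" if torch.cuda.is_available() else "cpu"
t5 = T5Embedder(
    device=device,
    local_cache=True,
    dir_or_name="flan-t5-xl",
    torch_dtype=torch.bfloat16,
    model_max_length=120,
)
# The prompt text is a placeholder.
caption_embs, emb_masks = t5.get_text_embeddings(["a red sports car on a mountain road"])
print(caption_embs.shape, emb_masks.shape)  # (1, 120, hidden_dim) and (1, 120)
```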
    	
        model.py
    ADDED
    
@@ -0,0 +1,242 @@
import gc
import spaces
from safetensors.torch import load_file
from autoregressive.models.gpt_t2i import GPT_models
from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
import torch
import numpy as np
import PIL
from PIL import Image
from condition.canny import CannyDetector
import time
from autoregressive.models.generate import generate
from condition.midas.depth import MidasDetector

models = {
    "canny": "checkpoints/t2i/canny_MR.safetensors",
    "depth": "checkpoints/t2i/depth_MR.safetensors",
}


def resize_image_to_16_multiple(image, condition_type='canny'):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # image = Image.open(image_path)
    width, height = image.size

    if condition_type == 'depth':  # The depth model requires a side length that is a multiple of 32
        new_width = (width + 31) // 32 * 32
        new_height = (height + 31) // 32 * 32
    else:
        new_width = (width + 15) // 16 * 16
        new_height = (height + 15) // 16 * 16

    resized_image = image.resize((new_width, new_height))
    return resized_image


class Model:

    def __init__(self):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.base_model_id = ""
        self.task_name = ""
        self.vq_model = self.load_vq()
        self.t5_model = self.load_t5()
        self.gpt_model_canny = self.load_gpt(condition_type='canny')
        self.gpt_model_depth = self.load_gpt(condition_type='depth')
        self.get_control_canny = CannyDetector()
        self.get_control_depth = MidasDetector(device=self.device)

    def load_vq(self):
        vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                      codebook_embed_dim=8)
        vq_model.to(self.device)
        vq_model.eval()
        checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
                                map_location="cpu")
        vq_model.load_state_dict(checkpoint["model"])
        del checkpoint
        print(f"image tokenizer is loaded")
        return vq_model

    def load_gpt(self, condition_type='canny'):
        gpt_ckpt = models[condition_type]
        precision = torch.bfloat16
        latent_size = 768 // 16
        gpt_model = GPT_models["GPT-XL"](
            block_size=latent_size**2,
            cls_token_num=120,
            model_type='t2i',
            condition_type=condition_type,
        ).to(device=self.device, dtype=precision)

        model_weight = load_file(gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
        print(f"gpt model is loaded")
        return gpt_model

    def load_t5(self):
        precision = torch.bfloat16
        t5_model = T5Embedder(
            device=self.device,
            local_cache=True,
            # cache_dir='checkpoints/t5-ckpt',
            dir_or_name='flan-t5-xl',
            torch_dtype=precision,
            model_max_length=120,
        )
        return t5_model

    @torch.no_grad()
    @spaces.GPU(enable_queue=True)
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: int,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:

        image = resize_image_to_16_multiple(image, 'canny')
        W, H = image.size
        print(W, H)
        condition_img = self.get_control_canny(np.array(image), low_threshold,
                                               high_threshold)
        condition_img = torch.from_numpy(condition_img[None, None,
                                                       ...]).repeat(
                                                           2, 3, 1, 1)
        condition_img = condition_img.to(self.device)
        condition_img = 2 * (condition_img / 255 - 0.5)
        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        print(f"processing left-padding...")
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f'  prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)
        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), 8, H // 16, W // 16]
        t1 = time.time()
        index_sample = generate(
            self.gpt_model_canny,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(
            index_sample, qzshape)  # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).cpu().detach().numpy().clip(
                    0, 255).astype(np.uint8)) for sample in samples
        ]
        del condition_img
        torch.cuda.empty_cache()
        return samples

    @torch.no_grad()
    @spaces.GPU(enable_queue=True)
    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: int,
        seed: int,
    ) -> list[PIL.Image.Image]:
        image = resize_image_to_16_multiple(image, 'depth')
        W, H = image.size
        print(W, H)
        image_tensor = torch.from_numpy(np.array(image)).to(self.device)
        condition_img = torch.from_numpy(
            self.get_control_depth(image_tensor)).unsqueeze(0)
        condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
        condition_img = condition_img.to(self.device)
        condition_img = 2 * (condition_img / 255 - 0.5)
        prompts = [prompt] * 2
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)

        print(f"processing left-padding...")
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        new_caption_embs = []
        for idx, (caption_emb,
                  emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f'  prompt {idx} token len: {valid_num}')
            new_caption_emb = torch.cat(
                [caption_emb[valid_num:], caption_emb[:valid_num]])
            new_caption_embs.append(new_caption_emb)
        new_caption_embs = torch.stack(new_caption_embs)

        c_indices = new_caption_embs * new_emb_masks[:, :, None]
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), 8, H // 16, W // 16]
        t1 = time.time()
        index_sample = generate(
            self.gpt_model_depth,
            c_indices,
            (H // 16) * (W // 16),
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(index_sample, qzshape)
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")
        condition_img = condition_img.cpu()
        samples = samples.cpu()
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        samples = [image] + [
            Image.fromarray(
                sample.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8))
            for sample in samples
        ]
        del image_tensor
        del condition_img
        torch.cuda.empty_cache()
        return samples
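A hedged sketch of driving the `Model` class above outside the Gradio app. The input image path, prompt, and sampling parameters here are placeholders, and the checkpoints referenced in `models` plus `checkpoints/vq_ds16_t2i.pt` must already be present.

```
# Illustrative usage of Model.process_canny; parameter values are placeholders.
import numpy as np
from PIL import Image
from model import Model

m = Model()  # loads the VQ tokenizer, the T5 embedder, and both GPT checkpoints (canny, depth)
image = np.array(Image.open("input.png").convert("RGB"))  # placeholder path
outputs = m.process_canny(
    image,
    prompt="a modern sofa in a bright living room",  # placeholder prompt
    cfg_scale=4.0,
    temperature=1.0,
    top_k=2000,
    top_p=1.0,
    seed=0,
    low_threshold=100,
    high_threshold=200,
)
# outputs: [resized input, canny condition image, two generated samples] as PIL images
for i, img in enumerate(outputs):
    img.save(f"canny_result_{i}.png")
```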
    	
        style.css
    ADDED
    
@@ -0,0 +1,10 @@
h1 {
  text-align: center;
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}
    	
        tokenizer/consistencydecoder/README.md
    ADDED
    
@@ -0,0 +1,14 @@
## Consistency Decoder from OpenAI

### Install
```
pip install diffusers
pip install accelerate
```

### Demo
```
cd ${THIS_REPO_ROOT}
python3 tokenizer/consistencydecoder/cd_demo.py
```
    	
        tokenizer/consistencydecoder/cd_demo.py
    ADDED
    
    | 
         @@ -0,0 +1,57 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
import argparse
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from diffusers import ConsistencyDecoderVAE


def main(args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to(device)

    # load image
    img_path = args.image_path
    out_path = args.image_path.replace('.jpg', '_cd.jpg').replace('.jpeg', '_cd.jpeg').replace('.png', '_cd.png')
    input_size = args.image_size
    img = Image.open(img_path).convert("RGB")

    # preprocess
    size_org = img.size
    img = img.resize((input_size, input_size))
    img = np.array(img) / 255.
    x = 2.0 * img - 1.0  # x value is between [-1, 1]
    x = torch.tensor(x)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    x_input = x.half().to(device)

    # inference
    with torch.no_grad():
        # Map input images to latent space + normalize latents:
        latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215)
        # reconstruct:
        output = vae.decode(latent / 0.18215).sample  # output value is between [-1, 1]

    # postprocess
    output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0]
    sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()

    # save
    Image.fromarray(sample).save(out_path)
    print("Reconstructed image is saved to {}".format(out_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image-path", type=str, default="assets/example.jpg")
    parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
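For quick reference, the demo above is usually invoked directly, e.g. `python tokenizer/consistencydecoder/cd_demo.py --image-path assets/example.jpg --image-size 512`, and its core is the latent round trip sketched below: encode to latents, scale by 0.18215, undo the scaling, decode. This is a minimal sketch, assuming `diffusers` is installed and a CUDA device is available; the random tensor is only a stand-in for a preprocessed image in [-1, 1].

# Minimal sketch of the encode/decode round trip used in cd_demo.py.
# Assumptions: diffusers is installed and a CUDA device is available.
import torch
from diffusers import ConsistencyDecoderVAE

vae = ConsistencyDecoderVAE.from_pretrained(
    "openai/consistency-decoder", torch_dtype=torch.float16
).to("cuda")

# Stand-in for a preprocessed 512x512 image scaled to [-1, 1].
x = torch.randn(1, 3, 512, 512, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latent = vae.encode(x).latent_dist.sample().mul_(0.18215)  # normalized latent
    recon = vae.decode(latent / 0.18215).sample                # back to [-1, 1]

print(latent.shape, recon.shape)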
    	
tokenizer/consistencydecoder/reconstruction_cd_ddp.py
ADDED
@@ -0,0 +1,208 @@
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from tqdm import tqdm
import os
import itertools
from PIL import Image
import numpy as np
import argparse
import random

from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss
from diffusers.models import ConsistencyDecoderVAE


class SingleFolderDataset(Dataset):
    def __init__(self, directory, transform=None):
        super().__init__()
        self.directory = directory
        self.transform = transform
        self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory)
                            if os.path.isfile(os.path.join(directory, file_name))]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(0)


def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)

    random.shuffle(samples)  # This is very important for IS(Inception Score) !!!
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


def main(args):
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup env
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device))

    # Create folder to save samples:
    folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    if args.dataset == 'imagenet':
        dataset = ImageFolder(args.data_path, transform=transform)
        num_fid_samples = 50000
    elif args.dataset == 'coco':
        dataset = SingleFolderDataset(args.data_path, transform=transform)
        num_fid_samples = 5000
    else:
        raise Exception("please check dataset")
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=args.per_proc_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False
    )

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    psnr_val_rgb = []
    ssim_val_rgb = []

    loader = tqdm(loader) if rank == 0 else loader
    total = 0
    for x, _ in loader:
        rgb_gts = x
        rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0  # rgb_gt value is between [0, 1]
        x = x.half().to("cuda:{}".format(device))
        with torch.no_grad():
            # Map input images to latent space + normalize latents:
            latent = vae.encode(x).latent_dist.sample().mul_(0.18215)
            # reconstruct:
            samples = vae.decode(latent / 0.18215).sample  # output value is between [-1, 1]
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
            # metric
            rgb_restored = sample.astype(np.float32) / 255.  # rgb_restored value is between [0, 1]
            psnr = psnr_loss(rgb_restored, rgb_gt)
            ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
            psnr_val_rgb.append(psnr)
            ssim_val_rgb.append(ssim)
        total += global_batch_size

    # ------------------------------------
    #       Summary
    # ------------------------------------
    # Make sure all processes have finished saving their samples
    dist.barrier()
    world_size = dist.get_world_size()
    gather_psnr_val = [None for _ in range(world_size)]
    gather_ssim_val = [None for _ in range(world_size)]
    dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
    dist.all_gather_object(gather_ssim_val, ssim_val_rgb)

    if rank == 0:
        gather_psnr_val = list(itertools.chain(*gather_psnr_val))
        gather_ssim_val = list(itertools.chain(*gather_ssim_val))
        psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
        ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
        print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))

        result_file = f"{sample_folder_dir}_results.txt"
        print("writing results to {}".format(result_file))
        with open(result_file, 'w') as f:
            print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)

        create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--sample-dir", type=str, default="reconstructions")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=4)
    args = parser.parse_args()
    main(args)
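A note on running and on the aggregation step: the script initializes an NCCL process group, so it is normally launched with one process per GPU, e.g. `torchrun --nproc_per_node=8 tokenizer/consistencydecoder/reconstruction_cd_ddp.py --data-path /path/to/imagenet/val --dataset imagenet --image-size 256` (the launcher is an assumption; any launcher that sets the standard torch.distributed environment variables should work). The per-rank PSNR/SSIM lists are pooled with `all_gather_object` before averaging on rank 0; the sketch below isolates that pattern with the CPU-only `gloo` backend so it can be tried without GPUs.

# Minimal sketch of the all_gather_object aggregation used above.
# Assumptions: launched under torchrun (any world size); "gloo" replaces the
# script's "nccl" backend so no GPU is needed for this sketch.
import itertools
import torch.distributed as dist

dist.init_process_group("gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()

local_vals = [float(rank), float(rank) + 0.5]    # stand-in for this rank's per-image PSNRs
gathered = [None for _ in range(world_size)]
dist.all_gather_object(gathered, local_vals)     # every rank receives every rank's list

if rank == 0:
    flat = list(itertools.chain(*gathered))
    print(f"mean over {len(flat)} values: {sum(flat) / len(flat):.3f}")

dist.barrier()
dist.destroy_process_group()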
    	
tokenizer/tokenizer_image/cache/vgg.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
size 7289
    	
tokenizer/tokenizer_image/discriminator.py
ADDED
@@ -0,0 +1,255 @@
# Modified from:
#   taming-transformers:  https://github.com/CompVis/taming-transformers
#   stylegan2-pytorch:    https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
#   maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
import functools
import math
import torch
import torch.nn as nn
try:
    from kornia.filters import filter2d
except:
    pass

#################################################################################
#                                    PatchGAN                                   #
#################################################################################
class PatchGANDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(PatchGANDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


#################################################################################
#                                    StyleGAN                                   #
#################################################################################
class StyleGANDiscriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
        super().__init__()
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(math.log(image_size, 2))
        in_channel = channels[image_size]

        blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            blocks.append(DiscriminatorBlock(in_channel, out_channel))
            in_channel = out_channel
        self.blocks = nn.ModuleList(blocks)

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channel, channels[4], 3, padding=1),
            leaky_relu(),
        )
        self.final_linear = nn.Sequential(
            nn.Linear(channels[4] * 4 * 4, channels[4]),
            leaky_relu(),
            nn.Linear(channels[4], 1)
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.final_conv(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)
        return x


class DiscriminatorBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)


def leaky_relu(p=0.2):
    return nn.LeakyReLU(p, inplace=True)


def exists(val):
    return val is not None
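The two discriminator heads above differ in what they output: the PatchGAN head returns a spatial map of per-patch logits, while the StyleGAN head reduces each image to a single logit. A minimal shape check follows, assuming the file is importable as `tokenizer.tokenizer_image.discriminator` (the import path is an assumption based on the repo layout) and that `kornia` is installed, since `Blur` relies on `kornia.filters.filter2d`.

# Minimal shape check for the two discriminators defined above.
# Assumptions: import path matches the repo layout; kornia is installed
# (StyleGANDiscriminator's Blur uses kornia.filters.filter2d at forward time).
import torch
from tokenizer.tokenizer_image.discriminator import (
    PatchGANDiscriminator, StyleGANDiscriminator,
)

x = torch.randn(2, 3, 256, 256)                  # a batch of RGB images in [-1, 1]

patch_d = PatchGANDiscriminator(input_nc=3, ndf=64, n_layers=3)
print(patch_d(x).shape)                          # (2, 1, H', W'): one logit per patch

style_d = StyleGANDiscriminator(input_nc=3, image_size=256)
print(style_d(x).shape)                          # (2, 1): one logit per image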
    	
tokenizer/tokenizer_image/discriminator_patchgan.py
ADDED
@@ -0,0 +1,152 @@
# Modified from:
#   taming-transformers:  https://github.com/CompVis/taming-transformers
import functools
import torch
import torch.nn as nn


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
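For orientation, the PatchGAN head above returns a grid of per-patch real/fake logits rather than a single score per image; the grid size depends on the input resolution and `n_layers`. A minimal usage sketch with the defaults from this file (the 256x256 input and batch size are illustrative):

# Illustrative sketch only, not part of the commit.
import torch
from tokenizer.tokenizer_image.discriminator_patchgan import NLayerDiscriminator

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
img = torch.randn(1, 3, 256, 256)
logits = disc(img)
print(logits.shape)   # expected: torch.Size([1, 1, 30, 30]), one logit per overlapping patch

Note that with `use_actnorm=True`, `ActNorm` fills its shift and scale from the statistics of the first batch it sees in training mode (flipping the `initialized` buffer), so the first forward pass determines the normalization.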
    	
        tokenizer/tokenizer_image/discriminator_stylegan.py
    ADDED
    
@@ -0,0 +1,101 @@
# Modified from:
#   stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py
#   stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
#   maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
import math
import torch
import torch.nn as nn
try:
    from kornia.filters import filter2d
except:
    pass

class Discriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
        super().__init__()
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(math.log(image_size, 2))
        in_channel = channels[image_size]

        blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            blocks.append(DiscriminatorBlock(in_channel, out_channel))
            in_channel = out_channel
        self.blocks = nn.ModuleList(blocks)

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channel, channels[4], 3, padding=1),
            leaky_relu(),
        )
        self.final_linear = nn.Sequential(
            nn.Linear(channels[4] * 4 * 4, channels[4]),
            leaky_relu(),
            nn.Linear(channels[4], 1)
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.final_conv(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)
        return x


class DiscriminatorBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu()
        )

        self.downsample = nn.Sequential(
            Blur(),
            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
        ) if downsample else None

    def forward(self, x):
        res = self.conv_res(x)
        x = self.net(x)
        if exists(self.downsample):
            x = self.downsample(x)
        x = (x + res) * (1 / math.sqrt(2))
        return x


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)


def leaky_relu(p=0.2):
    return nn.LeakyReLU(p, inplace=True)


def exists(val):
    return val is not None
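The resolution-keyed `channels` table controls how many residual blocks this discriminator stacks: the input is halved repeatedly until it reaches 4x4, flattened, and mapped to a single realism logit per image. A minimal sketch under the defaults in this file (`image_size=256`; the batch size is illustrative, and `kornia` is needed for `Blur`):

# Illustrative sketch only, not part of the commit.
import torch
from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator

disc = Discriminator(input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256)
img = torch.randn(2, 3, 256, 256)
score = disc(img)
print(score.shape)   # expected: torch.Size([2, 1]), one logit per image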
    	
        tokenizer/tokenizer_image/lpips.py
    ADDED
    
@@ -0,0 +1,164 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import os, hashlib
import requests
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
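LPIPS compares two images by unit-normalizing intermediate VGG16 activations, squaring their differences, reweighting each channel with the learned 1x1 convolutions in `NetLinLayer`, averaging spatially, and summing the five layer scores into one distance per image. A minimal sketch of using it as a perceptual reconstruction loss; the random tensors stand in for a real batch and its reconstruction and are assumed to be in the [-1, 1] range the scaling layer expects, and constructing `LPIPS()` loads the VGG weights plus the checkpoint shipped under `cache/vgg.pth`:

# Illustrative sketch only, not part of the commit.
import torch
from tokenizer.tokenizer_image.lpips import LPIPS

perceptual_loss = LPIPS().eval()
real = torch.rand(4, 3, 256, 256) * 2 - 1   # stand-in for ground-truth images in [-1, 1]
fake = torch.rand(4, 3, 256, 256) * 2 - 1   # stand-in for reconstructions in [-1, 1]
d = perceptual_loss(real, fake)             # shape (4, 1, 1, 1): one distance per image
print(d.mean().item())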
    	
        tokenizer/tokenizer_image/reconstruction_vq_ddp.py
    ADDED
    
@@ -0,0 +1,207 @@
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import argparse
import itertools

from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss
from dataset.augmentation import center_crop_arr
from dataset.build import build_dataset
from tokenizer.tokenizer_image.vq_model import VQ_models



def create_npz_from_sample_folder(sample_dir, num=50000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path



def main(args):
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    if "ema" in checkpoint:  # ema
        model_weight = checkpoint["ema"]
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    vq_model.load_state_dict(model_weight)
    del checkpoint

    # Create folder to save samples:
    folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}"
                  f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}")
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    if args.dataset == 'imagenet':
        dataset = build_dataset(args, transform=transform)
        num_fid_samples = 50000
    elif args.dataset == 'coco':
        dataset = build_dataset(args, transform=transform)
        num_fid_samples = 5000
    elif args.dataset == 'imagenet_code':
        dataset = build_dataset(args)
        num_fid_samples = 50000
    else:
        raise Exception("please check dataset")

    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
        seed=args.global_seed
                )
         
     | 
| 109 | 
         
            +
                loader = DataLoader(
         
     | 
| 110 | 
         
            +
                    dataset,
         
     | 
| 111 | 
         
            +
                    batch_size=args.per_proc_batch_size,
         
     | 
| 112 | 
         
            +
                    shuffle=False,
         
     | 
| 113 | 
         
            +
                    sampler=sampler,
         
     | 
| 114 | 
         
            +
                    num_workers=args.num_workers,
         
     | 
| 115 | 
         
            +
                    pin_memory=True,
         
     | 
| 116 | 
         
            +
                    drop_last=False
         
     | 
| 117 | 
         
            +
                )    
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
         
     | 
| 120 | 
         
            +
                n = args.per_proc_batch_size
         
     | 
| 121 | 
         
            +
                global_batch_size = n * dist.get_world_size()
         
     | 
| 122 | 
         
            +
                
         
     | 
| 123 | 
         
            +
                psnr_val_rgb = []
         
     | 
| 124 | 
         
            +
                ssim_val_rgb = []
         
     | 
| 125 | 
         
            +
                loader = tqdm(loader) if rank == 0 else loader
         
     | 
| 126 | 
         
            +
                total = 0
         
     | 
| 127 | 
         
            +
                # for x, _ in loader:
         
     | 
| 128 | 
         
            +
                for batch in loader:
         
     | 
| 129 | 
         
            +
                    x = batch['condition_imgs'].repeat(1,3,1,1)
         
     | 
| 130 | 
         
            +
                    # import pdb 
         
     | 
| 131 | 
         
            +
                    # pdb.set_trace()
         
     | 
| 132 | 
         
            +
                    if args.image_size_eval != args.image_size:
         
     | 
| 133 | 
         
            +
                        rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
         
     | 
| 134 | 
         
            +
                    else:
         
     | 
| 135 | 
         
            +
                        rgb_gts = x
         
     | 
| 136 | 
         
            +
                    rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1]
         
     | 
| 137 | 
         
            +
                    x = x.to(device, non_blocking=True)
         
     | 
| 138 | 
         
            +
                    with torch.no_grad():
         
     | 
| 139 | 
         
            +
                        latent, _, [_, _, indices] = vq_model.encode(x.float())
         
     | 
| 140 | 
         
            +
                        import pdb;pdb.set_trace()
         
     | 
| 141 | 
         
            +
                        samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1]
         
     | 
| 142 | 
         
            +
                        if args.image_size_eval != args.image_size:
         
     | 
| 143 | 
         
            +
                            samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
         
     | 
| 144 | 
         
            +
                    samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
         
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
            +
                    # Save samples to disk as individual .png files
         
     | 
| 147 | 
         
            +
                    for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
         
     | 
| 148 | 
         
            +
                        index = i * dist.get_world_size() + rank + total
         
     | 
| 149 | 
         
            +
                        # Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
         
     | 
| 150 | 
         
            +
                        # metric
         
     | 
| 151 | 
         
            +
                        rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1]
         
     | 
| 152 | 
         
            +
                        psnr = psnr_loss(rgb_restored, rgb_gt)
         
     | 
| 153 | 
         
            +
                        ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
         
     | 
| 154 | 
         
            +
                        psnr_val_rgb.append(psnr)
         
     | 
| 155 | 
         
            +
                        ssim_val_rgb.append(ssim)
         
     | 
| 156 | 
         
            +
                        
         
     | 
| 157 | 
         
            +
                    total += global_batch_size
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                # ------------------------------------
         
     | 
| 160 | 
         
            +
                #       Summary
         
     | 
| 161 | 
         
            +
                # ------------------------------------
         
     | 
| 162 | 
         
            +
                # Make sure all processes have finished saving their samples
         
     | 
| 163 | 
         
            +
                dist.barrier()
         
     | 
| 164 | 
         
            +
                world_size = dist.get_world_size()
         
     | 
| 165 | 
         
            +
                gather_psnr_val = [None for _ in range(world_size)]
         
     | 
| 166 | 
         
            +
                gather_ssim_val = [None for _ in range(world_size)]
         
     | 
| 167 | 
         
            +
                dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
         
     | 
| 168 | 
         
            +
                dist.all_gather_object(gather_ssim_val, ssim_val_rgb)
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
                if rank == 0:
         
     | 
| 171 | 
         
            +
                    gather_psnr_val = list(itertools.chain(*gather_psnr_val))
         
     | 
| 172 | 
         
            +
                    gather_ssim_val = list(itertools.chain(*gather_ssim_val))        
         
     | 
| 173 | 
         
            +
                    psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
         
     | 
| 174 | 
         
            +
                    ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
         
     | 
| 175 | 
         
            +
                    print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                    result_file = f"{sample_folder_dir}_results.txt"
         
     | 
| 178 | 
         
            +
                    print("writing results to {}".format(result_file))
         
     | 
| 179 | 
         
            +
                    with open(result_file, 'w') as f:
         
     | 
| 180 | 
         
            +
                        print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
         
     | 
| 183 | 
         
            +
                    print("Done.")
         
     | 
| 184 | 
         
            +
                
         
     | 
| 185 | 
         
            +
                dist.barrier()
         
     | 
| 186 | 
         
            +
                dist.destroy_process_group()
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
             
     | 
| 189 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 190 | 
         
            +
                parser = argparse.ArgumentParser()
         
     | 
| 191 | 
         
            +
                parser.add_argument("--data-path", type=str, default=None)
         
     | 
| 192 | 
         
            +
                parser.add_argument("--code-path", type=str, required=True)
         
     | 
| 193 | 
         
            +
                parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco', 'imagenet_code'], default='imagenet')
         
     | 
| 194 | 
         
            +
                parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
         
     | 
| 195 | 
         
            +
                parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
         
     | 
| 196 | 
         
            +
                parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
         
     | 
| 197 | 
         
            +
                parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
         
     | 
| 198 | 
         
            +
                parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
         
     | 
| 199 | 
         
            +
                parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
         
     | 
| 200 | 
         
            +
                parser.add_argument("--sample-dir", type=str, default="reconstructions")
         
     | 
| 201 | 
         
            +
                parser.add_argument("--per-proc-batch-size", type=int, default=32)
         
     | 
| 202 | 
         
            +
                parser.add_argument("--global-seed", type=int, default=0)
         
     | 
| 203 | 
         
            +
                parser.add_argument("--num-workers", type=int, default=4)
         
     | 
| 204 | 
         
            +
                parser.add_argument("--condition", type=str, choices=['canny', 'hed'], default='canny')
         
     | 
| 205 | 
         
            +
                parser.add_argument("--get-condition-img", type=bool, default=False)
         
     | 
| 206 | 
         
            +
                args = parser.parse_args()
         
     | 
| 207 | 
         
            +
                main(args)
         
     | 
    	
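As a reference for the value-range conventions the script above relies on: Normalize(mean=0.5, std=0.5) maps input pixels to [-1, 1], decoder outputs in [-1, 1] are mapped to uint8 via clamp(127.5 * x + 128.0), and PSNR/SSIM are computed on [0, 1] floats. A minimal, self-contained sketch of that round trip (the tensors below are placeholders, not repository code):

import torch

pixels_01 = torch.rand(1, 3, 8, 8)                # ToTensor output in [0, 1]
normalized = (pixels_01 - 0.5) / 0.5              # Normalize(mean=0.5, std=0.5) -> [-1, 1]
decoded = normalized                              # stand-in for vq_model.decode_code output in [-1, 1]
as_uint8 = torch.clamp(127.5 * decoded + 128.0, 0, 255).to(torch.uint8)
restored_01 = as_uint8.float() / 255.             # what psnr_loss / ssim_loss receive
assert torch.allclose(pixels_01, restored_01, atol=1 / 255)  # round trip is exact to within one gray level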
tokenizer/tokenizer_image/vq_demo.py ADDED
@@ -0,0 +1,84 @@
import torch
import torch.nn.functional as F

import os
import argparse
import numpy as np
from PIL import Image

from tokenizer.tokenizer_image.vq_model import VQ_models
from dataset.augmentation import center_crop_arr


def main(args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    model.to(device)
    model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    if "ema" in checkpoint:  # ema
        model_weight = checkpoint["ema"]
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    model.load_state_dict(model_weight)
    del checkpoint

    # output dir
    os.makedirs(args.output_dir, exist_ok=True)
    out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix))
    out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix))
    out_path = out_path.replace('.png', '_{}.png'.format(args.suffix))
    out_filename = out_path.split('/')[-1]
    out_path = os.path.join(args.output_dir, out_filename)

    # load and preprocess image
    pil_image = Image.open(args.image_path).convert("RGB")
    img = center_crop_arr(pil_image, args.image_size)
    img = np.array(img) / 255.
    x = 2.0 * img - 1.0  # x value is between [-1, 1]
    x = torch.tensor(x)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    x_input = x.float().to(device)

    # inference
    with torch.no_grad():
        latent, _, [_, _, indices] = model.encode(x_input)
        output = model.decode_code(indices, latent.shape)  # output value is between [-1, 1]

    # postprocess
    output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0]
    sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()

    # save
    Image.fromarray(sample).save(out_path)
    print("Reconstructed image is saved to {}".format(out_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image-path", type=str, default="assets/example.jpg")
    parser.add_argument("--output-dir", type=str, default="output_vq_demo")
    parser.add_argument("--suffix", type=str, default="tokenizer_image")
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
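For quick local checks, the demo can also be driven programmatically instead of via the command line. A minimal sketch, assuming the repository root is on PYTHONPATH and that vq_ckpt points at an actual VQ checkpoint (the path below is a placeholder):

from argparse import Namespace

from tokenizer.tokenizer_image.vq_demo import main

# Mirror the argparse defaults above; vq_ckpt is a placeholder path, not a file shipped with this commit.
args = Namespace(
    image_path="assets/example.jpg",
    output_dir="output_vq_demo",
    suffix="tokenizer_image",
    vq_model="VQ-16",
    vq_ckpt="path/to/vq_checkpoint.pt",  # placeholder; replace with a real checkpoint
    codebook_size=16384,
    codebook_embed_dim=8,
    image_size=512,
    seed=0,
)
main(args)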