jhaozhuang committed
Commit 77771e4 · Parent(s): e826d2f

app

This view is limited to 50 files because it contains too many changes. See raw diff.
- BidirectionalTranslation/LICENSE +26 -0
- BidirectionalTranslation/README.md +72 -0
- BidirectionalTranslation/data/__init__.py +100 -0
- BidirectionalTranslation/data/aligned_dataset.py +60 -0
- BidirectionalTranslation/data/base_dataset.py +164 -0
- BidirectionalTranslation/data/image_folder.py +66 -0
- BidirectionalTranslation/data/singleCo_dataset.py +85 -0
- BidirectionalTranslation/data/singleSr_dataset.py +73 -0
- BidirectionalTranslation/models/__init__.py +61 -0
- BidirectionalTranslation/models/base_model.py +277 -0
- BidirectionalTranslation/models/cycle_ganstft_model.py +103 -0
- BidirectionalTranslation/models/networks.py +1375 -0
- BidirectionalTranslation/options/base_options.py +142 -0
- BidirectionalTranslation/options/test_options.py +19 -0
- BidirectionalTranslation/requirements.txt +8 -0
- BidirectionalTranslation/scripts/test_western2manga.sh +49 -0
- BidirectionalTranslation/test.py +71 -0
- BidirectionalTranslation/util/html.py +86 -0
- BidirectionalTranslation/util/util.py +136 -0
- BidirectionalTranslation/util/visualizer.py +221 -0
- app.py +507 -0
- assets/example_0/input.jpg +0 -0
- assets/example_0/ref1.jpg +0 -0
- assets/example_1/input.jpg +0 -0
- assets/example_1/ref1.jpg +0 -0
- assets/example_1/ref2.jpg +0 -0
- assets/example_1/ref3.jpg +0 -0
- assets/example_2/input.png +0 -0
- assets/example_2/ref1.png +0 -0
- assets/example_2/ref2.png +0 -0
- assets/example_2/ref3.png +0 -0
- assets/example_3/input.png +0 -0
- assets/example_3/ref1.png +0 -0
- assets/example_3/ref2.png +0 -0
- assets/example_3/ref3.png +0 -0
- assets/example_4/input.jpg +0 -0
- assets/example_4/ref1.jpg +0 -0
- assets/example_4/ref2.jpg +0 -0
- assets/example_4/ref3.jpg +0 -0
- assets/example_5/input.png +0 -0
- assets/example_5/ref1.png +0 -0
- assets/example_5/ref2.png +0 -0
- assets/example_5/ref3.png +0 -0
- assets/mask.png +0 -0
- diffusers/.github/ISSUE_TEMPLATE/bug-report.yml +110 -0
- diffusers/.github/ISSUE_TEMPLATE/config.yml +4 -0
- diffusers/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
- diffusers/.github/ISSUE_TEMPLATE/feedback.md +12 -0
- diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml +31 -0
- diffusers/.github/ISSUE_TEMPLATE/translate.md +29 -0
 
    	
BidirectionalTranslation/LICENSE
ADDED
@@ -0,0 +1,26 @@
+Manga Filling Style Conversion with Screentone Variational Autoencoder
+
+Copyright (c) 2020 The Chinese University of Hong Kong
+
+Copyright and License Information: The source code, the binary executable, and all data files (hereafter, Software) are copyrighted by The Chinese University of Hong Kong and Tien-Tsin Wong (hereafter, Author), Copyright (c) 2021 The Chinese University of Hong Kong. All Rights Reserved.
+
+The Author grants to you ("Licensee") a non-exclusive license to use the Software for academic, research and commercial purposes, without fee. For commercial use, Licensee should submit a WRITTEN NOTICE to the Author. The notice should clearly identify the software package/system/hardware (name, version, and/or model number) using the Software. Licensee may distribute the Software to third parties provided that the copyright notice and this statement appears on all copies. Licensee agrees that the copyright notice and this statement will appear on all copies of the Software, or portions thereof. The Author retains exclusive ownership of the Software.
+
+Licensee may make derivatives of the Software, provided that such derivatives can only be used for the purposes specified in the license grant above.
+
+THE AUTHOR MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. THE AUTHOR SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE OR ITS DERIVATIVES.
+
+By using the source code, Licensee agrees to cite the following papers in
+Licensee's publication/work:
+
+  Minshan Xie, Chengze Li, Xueting Liu and Tien-Tsin Wong
+  "Manga Filling Style Conversion with Screentone Variational Autoencoder"
+  ACM Transactions on Graphics (SIGGRAPH Asia 2020 issue), Vol. 39, No. 6, December 2020, pp. 226:1-226:15.
+
+
+By using or copying the Software, Licensee agrees to abide by the intellectual property laws, and all other applicable laws of the U.S., and the terms of this license.
+
+Author shall have the right to terminate this license immediately by written notice upon Licensee's breach of, or non-compliance with, any of its terms.
+Licensee may be held legally responsible for any infringement that is caused or encouraged by Licensee's failure to abide by the terms of this license.
+
+For more information or comments, send mail to: ttwong@acm.org
    	
BidirectionalTranslation/README.md
ADDED
@@ -0,0 +1,72 @@
+# Bidirectional Translation
+
+PyTorch implementation for multimodal comic-to-manga translation.
+
+**Note**: The current software works well with PyTorch 1.6.0+.
+
+## Prerequisites
+- Linux
+- Python 3
+- CPU or NVIDIA GPU + CUDA CuDNN
+
+## Getting Started
+### Installation
+- Clone this repo:
+```bash
+git clone https://github.com/msxie/ScreenStyle.git
+cd ScreenStyle/MangaScreening
+```
+- Install PyTorch and dependencies from http://pytorch.org
+- Install the python library [tensorboardX](https://github.com/lanpa/tensorboardX)
+- Install other libraries. For pip users:
+```
+pip install -r requirements.txt
+```
+
+## Data preparation
+Training requires paired data (including manga images, western images, and their line drawings).
+Line drawings can be extracted using [MangaLineExtraction](https://github.com/ljsabc/MangaLineExtraction).
+
+```
+${DATASET}
+|-- color2manga
+|   |-- val
+|   |   |-- ${FOLDER}
+|   |   |   |-- imgs
+|   |   |   |   |-- 0001.png
+|   |   |   |   |-- ...
+|   |   |   |-- line
+|   |   |   |   |-- 0001.png
+|   |   |   |   |-- ...
+```
+
+### Use a Pre-trained Model
+- Download the pre-trained [ScreenVAE](https://drive.google.com/file/d/1OBxWHjijMwi9gfTOfDiFiHRZA_CXNSWr/view?usp=sharing) model and place it under the `checkpoints/ScreenVAE/` folder.
+- Download the pre-trained [color2manga](https://drive.google.com/file/d/18-N1W0t3igWLJWFyplNZ5Fa2YHWASCZY/view?usp=sharing) model and place it under the `checkpoints/color2manga/` folder.
+- Generate results with the model:
+```bash
+bash ./scripts/test_western2manga.sh
+```
+
+## Copyright and License
+You are granted the [LICENSE](LICENSE) for both academic and commercial usage.
+
+## Citation
+If you find the code helpful in your research or work, please cite the following paper:
+```
+@article{xie-2020-manga,
+        author   = {Minshan Xie and Chengze Li and Xueting Liu and Tien-Tsin Wong},
+        title    = {Manga Filling Style Conversion with Screentone Variational Autoencoder},
+        journal  = {ACM Transactions on Graphics (SIGGRAPH Asia 2020 issue)},
+        month    = {December},
+        year     = {2020},
+        volume   = {39},
+        number   = {6},
+        pages    = {226:1--226:15}
+    }
+```
+
+### Acknowledgements
+This code borrows heavily from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository.
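The README fixes the expected directory layout but ships no validator for it. A minimal sketch (not part of this commit; the root path is illustrative) that checks each folder has a line drawing for every image:

```python
# Hypothetical helper, not in the repository: sanity-check a dataset folder
# against the ${DATASET}/color2manga/val layout shown in the README above.
import os

root = 'datasets/color2manga/val'  # illustrative; substitute your ${DATASET} path
for folder in sorted(os.listdir(root)):
    imgs = set(os.listdir(os.path.join(root, folder, 'imgs')))
    lines = set(os.listdir(os.path.join(root, folder, 'line')))
    missing = imgs - lines
    if missing:
        print('%s: %d image(s) without a line drawing' % (folder, len(missing)))
```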
    	
BidirectionalTranslation/data/__init__.py
ADDED
@@ -0,0 +1,100 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>:                       return the size of dataset.
+    -- <__getitem__>:                   get a data point from data loader.
+    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import importlib
+import torch.utils.data
+from data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+    """Import the module "data/[dataset_name]_dataset.py".
+
+    In the file, the class called DatasetNameDataset() will
+    be instantiated. It has to be a subclass of BaseDataset,
+    and it is case-insensitive.
+    """
+    dataset_filename = "data." + dataset_name + "_dataset"
+    datasetlib = importlib.import_module(dataset_filename)
+
+    dataset = None
+    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+    for name, cls in datasetlib.__dict__.items():
+        if name.lower() == target_dataset_name.lower() \
+           and issubclass(cls, BaseDataset):
+            dataset = cls
+
+    if dataset is None:
+        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+    return dataset
+
+
+def get_option_setter(dataset_name):
+    """Return the static method <modify_commandline_options> of the dataset class."""
+    dataset_class = find_dataset_using_name(dataset_name)
+    return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt):
+    """Create a dataset given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+        This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from data import create_dataset
+        >>> dataset = create_dataset(opt)
+    """
+    data_loader = CustomDatasetDataLoader(opt)
+    dataset = data_loader.load_data()
+    return dataset
+
+
+class CustomDatasetDataLoader():
+    """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+    def __init__(self, opt):
+        """Initialize this class
+
+        Step 1: create a dataset instance given the name [dataset_mode]
+        Step 2: create a multi-threaded data loader.
+        """
+        self.opt = opt
+        dataset_class = find_dataset_using_name(opt.dataset_mode)
+        self.dataset = dataset_class(opt)
+        print("dataset [%s] was created" % type(self.dataset).__name__)
+
+        train_sampler = None
+        if len(opt.gpu_ids) > 1:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
+
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batch_size,
+            #shuffle=not opt.serial_batches,
+            num_workers=int(opt.num_threads),
+            pin_memory=True, sampler=train_sampler
+            )
+
+    def load_data(self):
+        return self
+
+    def __len__(self):
+        """Return the number of data in the dataset"""
+        return min(len(self.dataset), self.opt.max_dataset_size)
+
+    def __iter__(self):
+        """Return a batch of data"""
+        for i, data in enumerate(self.dataloader):
+            if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                break
+            yield data
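The module docstring above describes the plug-in contract ('dummy' resolves to dummy_dataset.py defining DummyDataset), but the referenced template_dataset.py is not part of this commit. A minimal sketch of what such a file could look like; the '--dummy_size' flag and the tensor shape are invented for illustration:

```python
# data/dummy_dataset.py -- illustrative sketch of the four-function contract
# described in data/__init__.py; not part of the repository.
import torch
from data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # optional: add dataset-specific flags here
        parser.add_argument('--dummy_size', type=int, default=8)
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)  # always call the base initializer first

    def __len__(self):
        return self.opt.dummy_size

    def __getitem__(self, index):
        # return a dict of tensors plus metadata, mirroring AlignedDataset below
        return {'A': torch.zeros(3, 256, 256), 'A_paths': 'dummy_%04d' % index}
```

With this file in place, find_dataset_using_name('dummy') resolves the class by its lowercase name, and passing '--dataset_mode dummy' selects it.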
    	
BidirectionalTranslation/data/aligned_dataset.py
ADDED
@@ -0,0 +1,60 @@
+import os.path
+from data.base_dataset import BaseDataset, get_params, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class AlignedDataset(BaseDataset):
+    """A dataset class for paired image dataset.
+
+    It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
+    During test time, you need to prepare a directory '/path/to/data/test'.
+    """
+
+    def __init__(self, opt):
+        """Initialize this dataset class.
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        BaseDataset.__init__(self, opt)
+        self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
+        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
+        assert(self.opt.load_size >= self.opt.crop_size)   # crop_size should be smaller than the size of loaded image
+        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
+        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
+
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index - - a random integer for data indexing
+
+        Returns a dictionary that contains A, B, A_paths and B_paths
+            A (tensor) - - an image in the input domain
+            B (tensor) - - its corresponding image in the target domain
+            A_paths (str) - - image paths
+            B_paths (str) - - image paths (same as A_paths)
+        """
+        # read an image given a random integer index
+        AB_path = self.AB_paths[index % len(self.AB_paths)]
+        AB = Image.open(AB_path).convert('RGB')
+        # split AB image into A and B
+        w, h = AB.size
+        w2 = int(w / 2)
+        A = AB.crop((0, 0, w2, h))
+        B = AB.crop((w2, 0, w, h))
+
+        # apply the same transform to both A and B
+        transform_params = get_params(self.opt, A.size)
+        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
+        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
+
+        A = A_transform(A)
+        B = B_transform(B)
+
+        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
+
+    def __len__(self):
+        """Return the total number of images in the dataset."""
+        return len(self.AB_paths) * 100  # repeat the image list; __getitem__ indexes modulo its true length
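AlignedDataset expects each file to be one image with domain A on the left half and domain B on the right, split at w/2. A short sketch of composing such a pair with PIL; the file paths are invented:

```python
# Sketch: build the side-by-side {A,B} image that AlignedDataset splits at w/2.
# All paths are illustrative.
from PIL import Image

a = Image.open('domainA/0001.png').convert('RGB')
b = Image.open('domainB/0001.png').convert('RGB')
assert a.size == b.size, 'A and B must match pixel-for-pixel'
ab = Image.new('RGB', (a.width * 2, a.height))
ab.paste(a, (0, 0))        # left half  -> domain A
ab.paste(b, (a.width, 0))  # right half -> domain B
ab.save('path/to/data/train/0001.png')
```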
    	
BidirectionalTranslation/data/base_dataset.py
ADDED
@@ -0,0 +1,164 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image, ImageOps
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+    """This class is an abstract base class (ABC) for datasets.
+
+    To create a subclass, you need to implement the following four functions:
+    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>:                       return the size of dataset.
+    -- <__getitem__>:                   get a data point.
+    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the class; save the options in the class
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        self.opt = opt
+        self.root = opt.dataroot
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new dataset-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def __len__(self):
+        """Return the total number of images in the dataset."""
+        return 0
+
+    @abstractmethod
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index - - a random integer for data indexing
+
+        Returns:
+            a dictionary of data with their names. It usually contains the data itself and its metadata information.
+        """
+        pass
+
+
+def get_params(opt, size):
+    w, h = size
+    new_h = h
+    new_w = w
+    crop = 0
+    if opt.preprocess == 'resize_and_crop':
+        new_h = new_w = opt.load_size
+    elif opt.preprocess == 'scale_width_and_crop':
+        new_w = opt.load_size
+        new_h = opt.load_size * h // w
+
+    # x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
+    # y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
+
+    x = random.randint(crop, np.maximum(0, new_w - opt.crop_size - crop))
+    y = random.randint(crop, np.maximum(0, new_h - opt.crop_size - crop))
+
+    flip = random.random() > 0.5
+
+    return {'crop_pos': (x, y), 'flip': flip}
+
+
+def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
+    transform_list = []
+    if grayscale:
+        transform_list.append(transforms.Grayscale(1))
+    if 'resize' in opt.preprocess:
+        osize = [opt.load_size, opt.load_size]
+        transform_list.append(transforms.Resize(osize, method))
+    elif 'scale_width' in opt.preprocess:
+        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
+
+    if 'crop' in opt.preprocess:
+        if params is None:
+            # transform_list.append(transforms.RandomCrop(opt.crop_size))
+            transform_list.append(transforms.CenterCrop(opt.crop_size))
+        else:
+            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+
+    if opt.preprocess == 'none':
+        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=2**8, method=method)))
+
+    if not opt.no_flip:
+        if params is None:
+            transform_list.append(transforms.RandomHorizontalFlip())
+        elif params['flip']:
+            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+    if convert:
+        transform_list += [transforms.ToTensor()]
+        if grayscale:
+            transform_list += [transforms.Normalize((0.5,), (0.5,))]
+        else:
+            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+
+def __make_power_2(img, base, method=Image.BICUBIC):
+    ow, oh = img.size
+    h = int((oh + base - 1) // base * base)  # round height up to a multiple of base
+    w = int((ow + base - 1) // base * base)  # round width up to a multiple of base
+    if (h == oh) and (w == ow):
+        return img
+
+    __print_size_warning(ow, oh, w, h)
+    return ImageOps.expand(img, (0, 0, w - ow, h - oh), fill=255)
+
+
+def __scale_width(img, target_width, method=Image.BICUBIC):
+    ow, oh = img.size
+    if (ow == target_width):
+        return img
+    w = target_width
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), method)
+
+
+def __crop(img, pos, size):
+    ow, oh = img.size
+    x1, y1 = pos
+    tw = th = size
+    if (ow > tw or oh > th):
+        return img.crop((x1, y1, x1 + tw, y1 + th))
+    return img
+
+
+def __flip(img, flip):
+    if flip:
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+    return img
+
+
+def __print_size_warning(ow, oh, w, h):
+    """Print warning information about image size (only print once)"""
+    if not hasattr(__print_size_warning, 'has_printed'):
+        print("The image size needs to be a multiple of the padding base. "
+              "The loaded image size was (%d, %d), so it was padded to "
+              "(%d, %d). This adjustment will be done to all images "
+              "whose sizes are not multiples of the base" % (ow, oh, w, h))
+        __print_size_warning.has_printed = True
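The key pattern in this module is drawing random crop/flip parameters once with get_params and replaying them through get_transform, which is how AlignedDataset above keeps both halves of a pair aligned. A standalone sketch; the option values and image paths are invented:

```python
# Sketch: one random draw of crop/flip parameters applied to two images.
# Option values below are illustrative, not repository defaults.
from argparse import Namespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286,
                crop_size=256, no_flip=False)
img_a = Image.open('a.png').convert('RGB')  # illustrative paths
img_b = Image.open('b.png').convert('RGB')

params = get_params(opt, img_a.size)  # random crop position + flip decision
tf = get_transform(opt, params)       # deterministic once params are fixed
tensor_a = tf(img_a)                  # both tensors receive the identical
tensor_b = tf(img_b)                  # crop window and flip
```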
    	
        BidirectionalTranslation/data/image_folder.py
    ADDED
    
    | 
         @@ -0,0 +1,66 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both the current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG', '.npz', '.npy',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image paths under dir, capped at max_dataset_size."""
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
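For reference, ImageFolder can be exercised on its own; a minimal sketch (the ./samples directory and the transform pipeline are illustrative, not part of the repo):

# Minimal usage sketch for ImageFolder; any folder of images, nested
# arbitrarily, works because make_dataset walks subdirectories.
import torchvision.transforms as transforms
from data.image_folder import ImageFolder

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),   # PIL image -> float tensor in [0, 1], CHW
])

dataset = ImageFolder(root='./samples', transform=transform, return_paths=True)
img, path = dataset[0]       # (tensor, source path) because return_paths=True
print(len(dataset), img.shape, path)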
    	
BidirectionalTranslation/data/singleCo_dataset.py ADDED
@@ -0,0 +1,85 @@
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset
from PIL import Image, ImageEnhance
import random
import numpy as np
import torch
import torch.nn.functional as F
import cv2


class SingleCoDataset(BaseDataset):
    """Test-time dataset: loads a color image A and its line drawing L
    (the B side of the returned dict is filled with placeholder zeros)."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def __init__(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, opt.phase, opt.folder, 'imgs')

        self.A_paths = make_dataset(self.dir_A)
        self.A_paths = sorted(self.A_paths)
        self.A_size = len(self.A_paths)
        # self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index]

        A_img = Image.open(A_path).convert('RGB')
        # enhancer = ImageEnhance.Brightness(A_img)
        # A_img = enhancer.enhance(1.5)
        if os.path.exists(A_path.replace('imgs', 'line')[:-4] + '.jpg'):
            # L_img = Image.open(A_path.replace('imgs','line')[:-4]+'.png')
            L_img = cv2.imread(A_path.replace('imgs', 'line')[:-4] + '.jpg')
            kernel = np.ones((3, 3), np.uint8)
            # erosion grows the dark regions, slightly thickening the strokes
            L_img = cv2.erode(L_img, kernel, iterations=1)
            L_img = Image.fromarray(L_img)
        else:
            L_img = A_img  # fall back to the color image when no line drawing exists
        if A_img.size != L_img.size:
            # L_img = L_img.resize(A_img.size, Image.ANTIALIAS)
            A_img = A_img.resize(L_img.size, Image.ANTIALIAS)  # ANTIALIAS == LANCZOS
        if A_img.size[1] > 2500:  # downscale very tall pages by half
            A_img = A_img.resize((A_img.size[0] // 2, A_img.size[1] // 2), Image.ANTIALIAS)

        ow, oh = A_img.size
        # one set of params shared by both transforms, so A and L get identical crop/flip
        transform_params = get_params(self.opt, A_img.size)
        A_transform = get_transform(self.opt, transform_params, grayscale=False)
        L_transform = get_transform(self.opt, transform_params, grayscale=True)
        A = A_transform(A_img)
        L = L_transform(L_img)

        # base = 2**9
        # h = int((oh+base-1) // base * base)
        # w = int((ow+base-1) // base * base)
        # A = F.pad(A.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
        # L = F.pad(L.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)

        # ITU-R BT.601 luma of A, kept as a single-channel intensity image
        tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
        Ai = tmp.unsqueeze(0)

        return {'A': A, 'Ai': Ai, 'L': L,
                'B': torch.zeros(1), 'Bs': torch.zeros(1), 'Bi': torch.zeros(1), 'Bl': torch.zeros(1),
                'A_paths': A_path, 'h': oh, 'w': ow}

    def __len__(self):
        return self.A_size

    def name(self):
        return 'SingleCoDataset'


def M_transform(feat, opt, params=None):
    outfeat = feat.copy()
    oh, ow = feat.shape[1:]
    x1, y1 = params['crop_pos']
    tw = th = opt.crop_size
    if ow > tw or oh > th:
        outfeat = outfeat[:, y1:y1 + th, x1:x1 + tw]
    if params['flip']:
        outfeat = np.flip(outfeat, 2)  # outfeat[:,:,::-1]
    return torch.from_numpy(outfeat.copy()).float() * 2 - 1.0
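Note that get_params is computed once and shared by both transforms, so A and L receive exactly the same crop and flip. A sketch of standalone use (the option object and paths are hypothetical; the preprocess/load_size/crop_size/no_flip fields follow the usual base_dataset.py conventions):

# Hypothetical options object with only the fields this dataset reads.
from types import SimpleNamespace
from data.singleCo_dataset import SingleCoDataset

opt = SimpleNamespace(dataroot='./datasets/manga', phase='test', folder='color',
                      preprocess='none', load_size=512, crop_size=512, no_flip=True)

dataset = SingleCoDataset(opt)   # expects images under ./datasets/manga/test/color/imgs/
sample = dataset[0]
print(sample['A'].shape, sample['Ai'].shape, sample['L'].shape, sample['A_paths'])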
    	
BidirectionalTranslation/data/singleSr_dataset.py ADDED
@@ -0,0 +1,73 @@
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
import numpy as np
import torch
import torch.nn.functional as F


class SingleSrDataset(BaseDataset):
    """Test-time dataset: loads a screened image B and its line drawing Bl
    (the A side of the returned dict is filled with placeholder zeros)."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def __init__(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_B = os.path.join(opt.dataroot, opt.phase, opt.folder, 'imgs')
        # self.dir_B = os.path.join(opt.dataroot, opt.phase, 'test/imgs', opt.folder)

        self.B_paths = make_dataset(self.dir_B)
        self.B_paths = sorted(self.B_paths)
        self.B_size = len(self.B_paths)
        # self.transform = get_transform(opt)
        # print(self.B_size)

    def __getitem__(self, index):
        B_path = self.B_paths[index]

        B_img = Image.open(B_path).convert('RGB')
        # the line drawing lives under 'line/' with either a .png or .jpg extension
        if os.path.exists(B_path.replace('imgs', 'line').replace('.jpg', '.png')):
            L_img = Image.open(B_path.replace('imgs', 'line').replace('.jpg', '.png'))  # .convert('RGB')
        else:
            L_img = Image.open(B_path.replace('imgs', 'line').replace('.png', '.jpg'))  # .convert('RGB')
        B_img = B_img.resize(L_img.size, Image.ANTIALIAS)

        ow, oh = B_img.size
        # shared params keep B and its line drawing aligned under crop/flip
        transform_params = get_params(self.opt, B_img.size)
        B_transform = get_transform(self.opt, transform_params, grayscale=True)
        B = B_transform(B_img)
        L = B_transform(L_img)

        # base = 2**8
        # h = int((oh+base-1) // base * base)
        # w = int((ow+base-1) // base * base)
        # B = F.pad(B.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
        # L = F.pad(L.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)

        return {'B': B, 'Bs': B, 'Bi': B, 'Bl': L,
                'A': torch.zeros(1), 'Ai': torch.zeros(1), 'L': torch.zeros(1),
                'A_paths': B_path, 'h': oh, 'w': ow}

    def __len__(self):
        return self.B_size

    def name(self):
        return 'SingleSrDataset'


def M_transform(feat, opt, params=None):
    outfeat = feat.copy()
    if params is not None:
        oh, ow = feat.shape[1:]
        x1, y1 = params['crop_pos']
        tw = th = opt.crop_size
        if ow > tw or oh > th:
            outfeat = outfeat[:, y1:y1 + th, x1:x1 + tw]
        if params['flip']:
            outfeat = np.flip(outfeat, 2).copy()  # outfeat[:,:,::-1]
    return torch.from_numpy(outfeat).float() * 2 - 1.0
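Both single-image datasets ship a near-identical M_transform helper: crop to opt.crop_size, optionally flip, then rescale a [0, 1] array to [-1, 1]. A synthetic round-trip (all inputs are made up):

# Synthetic check of M_transform on a CHW float array in [0, 1].
import numpy as np
from types import SimpleNamespace
from data.singleSr_dataset import M_transform

feat = np.random.rand(3, 300, 300).astype(np.float32)
opt = SimpleNamespace(crop_size=256)
params = {'crop_pos': (10, 20), 'flip': True}

out = M_transform(feat, opt, params)
print(out.shape)                            # torch.Size([3, 256, 256])
print(out.min().item(), out.max().item())   # both now lie in [-1, 1]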
    	
BidirectionalTranslation/models/__init__.py ADDED
@@ -0,0 +1,61 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for an example.
"""

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called [ModelName]Model will
    be instantiated. It has to be a subclass of BaseModel,
    and the name matching is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt, ckpt_root):
    """Create a model given the option.

    This function wraps the model class.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt, ckpt_root)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt, ckpt_root=ckpt_root)
    print("model [%s] was created" % type(instance).__name__)
    return instance
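The package docstring above prescribes the extension recipe; a skeleton of the hypothetical models/dummy_model.py it describes could look as follows (a sketch only; note that create_model in this repo also forwards ckpt_root to the constructor):

# models/dummy_model.py -- hypothetical skeleton following the recipe above.
from models.base_model import BaseModel


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser                          # no model-specific options

    def __init__(self, opt, ckpt_root=None):   # create_model() passes ckpt_root
        BaseModel.__init__(self, opt)
        self.loss_names = []
        self.model_names = []                  # nothing to save/load in this sketch
        self.visual_names = ['real', 'fake']
        self.optimizers = []

    def set_input(self, input):
        self.real = input['A'].to(self.device)

    def forward(self):
        self.fake = self.real                  # identity pass-through

    def optimize_parameters(self):
        pass                                   # no training step in this sketch

# selected at runtime with '--model dummy'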
    	
BidirectionalTranslation/models/base_model.py ADDED
@@ -0,0 +1,277 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks
import numpy as np
from skimage.color import lab2rgb  # assumed source for the bare lab2rgb call in get_current_visuals
from torch.nn.parallel import DistributedDataParallel as DDP

class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call `BaseModel.__init__(self, opt)`
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.model_names (str list):         define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.iter = 0
        self.last_iter = 0
        self.device = torch.device('cuda:{}'.format(
            self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        # save all the checkpoints to save_dir
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        os.makedirs(self.save_dir, exist_ok=True)
        # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
        if opt.preprocess != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []

        # a fixed random palette for rendering superpixel label maps
        self.label_colours = np.random.randint(255, size=(100, 3))

    def save_suppixel(self, l_inds):
        """Render an integer label map as an RGB image in [-1, 1] using the random palette."""
        im_target_rgb = np.array([self.label_colours[c % 100] for c in l_inds])
        b, h, w = l_inds.shape
        im_target_rgb = im_target_rgb.reshape(b, h, w, 3).transpose(0, 3, 1, 2) / 127.5 - 1.0
        return torch.from_numpy(im_target_rgb)

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    def is_train(self):
        """Check whether the current batch is suitable for training."""
        return True

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(
                optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.epoch)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to an HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                if 'Lab' in name:
                    labimg = getattr(self, name).cpu()
                    labimg[:, 0, :, :] += 1
                    labimg[:, 0, :, :] *= 50      # L: [-1, 1] -> [0, 100]
                    labimg[:, 1:, :, :] *= 110    # a, b: [-1, 1] -> [-110, 110]
                    labimg = labimg.permute((0, 2, 3, 1))
                    for i in range(labimg.shape[0]):
                        labimg[i, :, :, :] = lab2rgb(labimg[i, :, :, :])
                    visual_ret[name] = (labimg.permute((0, 3, 1, 2)) * 2 - 1.0).to(self.device)
                elif 'Fm' in name:
                    visual_ret[name] = self.save_suppixel(getattr(self, name).cpu()).to(self.device)
                else:
                    visual_ret[name] = getattr(self, name)
        return visual_ret

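    # Note on the 'Lab' branch above: lab2rgb (assumed to be skimage.color.lab2rgb)
    # expects L in [0, 100] and a/b roughly in [-110, 110]; the three in-place ops
    # undo the network's [-1, 1] normalization, e.g. for a synthetic HWC array:
    #     lab = np.random.uniform(-1, 1, (8, 8, 3))
    #     lab[..., 0] = (lab[..., 0] + 1) * 50   # L: [-1, 1] -> [0, 100]
    #     lab[..., 1:] *= 110                    # a, b: [-1, 1] -> [-110, 110]
    #     rgb = lab2rgb(lab)                     # float RGB in [0, 1]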
    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                # print(save_path)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.state_dict(), save_path)
                    # net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

        # also persist optimizer states and a print_freq-aligned iteration counter
        save_filename = '%s_net_opt.pth' % (epoch)
        save_path = os.path.join(self.save_dir, save_filename)
        save_dict = {'iter': str(self.iter // self.opt.print_freq * self.opt.print_freq)}
        for i, name in enumerate(self.optimizer_names):
            save_dict.update({name.lower(): self.optimizers[i].state_dict()})
        torch.save(save_dict, save_path)

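    # On-disk layout produced by save_networks (and consumed by load_networks),
    # e.g. for a hypothetical model with model_names = ['G', 'D'] saved as 'latest':
    #     <checkpoints_dir>/<name>/latest_net_G.pth    # state_dict of self.netG
    #     <checkpoints_dir>/<name>/latest_net_D.pth    # state_dict of self.netD
    #     <checkpoints_dir>/<name>/latest_net_opt.pth  # optimizer states + 'iter' counter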
    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                # if isinstance(net, torch.nn.DataParallel):
                if isinstance(net, DDP):
                    net = net.module
                # print(net)
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(
                    load_path, map_location=lambda storage, loc: storage.cuda())
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                # need to copy keys here because we mutate in loop
                # for key in list(state_dict.keys()):
                #     self.__patch_instance_norm_state_dict(
                #         state_dict, net, key.split('.'))

                net.load_state_dict(state_dict)
                del state_dict

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
         
     | 
| 257 | 
         
            +
                            num_params = 0
         
     | 
| 258 | 
         
            +
                            for param in net.parameters():
         
     | 
| 259 | 
         
            +
                                num_params += param.numel()
         
     | 
| 260 | 
         
            +
                            if verbose:
         
     | 
| 261 | 
         
            +
                                print(net)
         
     | 
| 262 | 
         
            +
                            print('[Network %s] Total number of parameters : %.3f M' %
         
     | 
| 263 | 
         
            +
                                  (name, num_params / 1e6))
         
     | 
| 264 | 
         
            +
                    print('-----------------------------------------------')
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                def set_requires_grad(self, nets, requires_grad=False):
         
     | 
| 267 | 
         
            +
                    """Set requires_grad=False for all the networks to avoid unnecessary computations
         
     | 
| 268 | 
         
            +
                    Parameters:
         
     | 
| 269 | 
         
            +
                        nets (network list)   -- a list of networks
         
     | 
| 270 | 
         
            +
                        requires_grad (bool)  -- whether the networks require gradients or not
         
     | 
| 271 | 
         
            +
                    """
         
     | 
| 272 | 
         
            +
                    if not isinstance(nets, list):
         
     | 
| 273 | 
         
            +
                        nets = [nets]
         
     | 
| 274 | 
         
            +
                    for net in nets:
         
     | 
| 275 | 
         
            +
                        if net is not None:
         
     | 
| 276 | 
         
            +
                            for param in net.parameters():
         
     | 
| 277 | 
         
            +
                                param.requires_grad = requires_grad
         
     | 
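
A minimal, self-contained sketch of the alternating-update pattern that set_requires_grad supports: the discriminator is frozen during the generator step so that it accumulates no gradients. The toy netG/netD modules, optimizers, and losses below are illustrative stand-ins, not part of this repository.

import torch
import torch.nn as nn

def set_requires_grad(nets, requires_grad=False):
    # same logic as BaseModel.set_requires_grad, as a free function
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

netG, netD = nn.Linear(8, 8), nn.Linear(8, 1)   # toy generator/discriminator
opt_G = torch.optim.Adam(netG.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(netD.parameters(), lr=2e-4)
x = torch.randn(4, 8)

set_requires_grad(netD, True)                   # D step: D trainable
opt_D.zero_grad()
netD(netG(x).detach()).mean().backward()        # detach G so only D learns
opt_D.step()

set_requires_grad(netD, False)                  # G step: freeze D
opt_G.zero_grad()
netD(netG(x)).mean().backward()                 # grads flow to G through frozen D
opt_G.step()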
    	
BidirectionalTranslation/models/cycle_ganstft_model.py ADDED
@@ -0,0 +1,103 @@
import random
import torch
from .base_model import BaseModel
from . import networks
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


class CycleGANSTFTModel(BaseModel):

    def __init__(self, opt, ckpt_root):
        BaseModel.__init__(self, opt)

        use_vae = True
        self.interchnnls = 4
        use_noise = False
        self.half_size = opt.batch_size // 2
        self.device = opt.local_rank
        self.gpu_ids = [self.device]
        self.local_rank = opt.local_rank
        self.cropsize = opt.crop_size

        self.model_names = ['G_INTSCR2RGB', 'G_RGB2INTSCR', 'E']
        self.netG_INTSCR2RGB = networks.define_G(self.interchnnls + 1, 3, opt.nz, opt.ngf, netG='unet_256',
                                                 norm='layer', nl='lrelu', use_dropout=opt.use_dropout,
                                                 init_type='kaiming', init_gain=opt.init_gain,
                                                 gpu_ids=self.gpu_ids, where_add='all', upsample='bilinear',
                                                 use_noise=use_noise)
        self.netG_RGB2INTSCR = networks.define_G(4, self.interchnnls, 0, opt.ngf, netG='unet_256',
                                                 norm='layer', nl='lrelu', use_dropout=opt.use_dropout,
                                                 init_type='kaiming', init_gain=opt.init_gain,
                                                 gpu_ids=self.gpu_ids, where_add='input', upsample='bilinear',
                                                 use_noise=use_noise)
        self.netE = networks.define_E(opt.output_nc, opt.nz, opt.nef, netE=opt.netE, norm='none', nl='lrelu',
                                      init_type='xavier', init_gain=opt.init_gain, gpu_ids=self.gpu_ids,
                                      vaeLike=use_vae)
        self.nets = [self.netG_INTSCR2RGB, self.netG_RGB2INTSCR, self.netE]

        self.netSVAE = networks.define_SVAE(inc=1, outc=self.interchnnls, outplanes=64, blocks=3, netVAE='SVAE',
                                            save_dir=ckpt_root + '/ScreenStyle/ScreenVAE',
                                            init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids)

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_RGB = input['A'].to(self.device)
        self.real_Ai = self.grayscale(self.real_RGB)
        self.real_L = input['L'].to(self.device)
        self.real_ML = input['Bl'].to(self.device)
        self.real_M = input['B'].to(self.device)

        self.h = input['h']
        self.w = input['w']

    def grayscale(self, input_image):
        # Rec. 601 luma weights, broadcast over the channel dimension
        rate = torch.Tensor([0.299, 0.587, 0.114]).reshape(1, 3, 1, 1).to(input_image.device)
        return (input_image * rate).sum(1, keepdim=True)

    def forward(self, AtoB=True, sty=None):
        if AtoB:  # RGB -> intermediate screentone representation -> manga
            real_LRGB = torch.cat([self.real_L, self.real_RGB], 1)
            fake_SCR = self.netG_RGB2INTSCR(real_LRGB)
            fake_M = self.netSVAE(fake_SCR, line=self.real_L, img_input=False)
            fake_M = torch.clamp(fake_M, -1, 1)
            fake_M2 = self.norm(torch.mul(self.denorm(fake_M), self.denorm(self.real_L)))
            return fake_M[:, :, :self.h, :self.w], fake_M2[:, :, :self.h, :self.w], fake_SCR[:, :, :self.h, :self.w]
        else:  # manga -> RGB, with a style code from the encoder or the caller
            if sty is None:  # use encoded z
                z0, _ = self.netE(self.real_RGB)
            else:
                z0 = sty
            real_SCR = self.netSVAE(self.real_M, self.real_ML, output_screen_only=True)
            real_LSCR = torch.cat([self.real_ML, real_SCR], 1)
            fake_nRGB = self.netG_INTSCR2RGB(real_LSCR, z0)
            fake_nRGB = torch.clamp(fake_nRGB, -1, 1)
            fake_RGB = self.norm(torch.mul(self.denorm(fake_nRGB), self.denorm(self.real_ML)))
            return fake_RGB[:, :, :self.h, :self.w], real_SCR[:, :, :self.h, :self.w], self.real_ML[:, :, :self.h, :self.w]

    def norm(self, im):
        # [0, 1] -> [-1, 1]
        return im * 2.0 - 1

    def denorm(self, im):
        # [-1, 1] -> [0, 1]
        return (im + 1) / 2.0

    def optimize_parameters(self):
        pass  # inference-only model: no training step

    def get_z_random(self, batch_size, nz, random_type='gauss', truncation=False, tvalue=1):
        z = None
        if random_type == 'uni':
            z = torch.rand(batch_size, nz) * 2.0 - 1.0
        elif random_type == 'gauss':
            z = torch.randn(batch_size, nz) * tvalue
            # the truncation trick: resample entries outside [-tvalue, tvalue]
            if truncation:
                k = 0
                while k < 15 * nz:
                    if torch.max(torch.abs(z)) <= tvalue:
                        break
                    zabs = torch.abs(z)
                    zz = torch.randn(batch_size, nz)
                    z[zabs > tvalue] = zz[zabs > tvalue]
                    k += 1
                z = torch.clamp(z, -tvalue, tvalue)

        return z.detach().to(self.device)
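
The truncation trick in get_z_random resamples only the out-of-range latent entries instead of redrawing the whole vector, and clamps whatever remains after the retry budget is spent. A standalone sketch of the same idea (the function name and defaults are illustrative, not part of this repository):

import torch

def truncated_gauss(batch_size, nz, tvalue=1.0, max_tries=100):
    # draw a Gaussian latent, then resample entries outside [-tvalue, tvalue]
    z = torch.randn(batch_size, nz) * tvalue
    for _ in range(max_tries):
        mask = z.abs() > tvalue
        if not mask.any():
            break
        z[mask] = (torch.randn(batch_size, nz) * tvalue)[mask]
    return torch.clamp(z, -tvalue, tvalue)

z = truncated_gauss(4, 8)
assert z.abs().max() <= 1.0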
    	
BidirectionalTranslation/models/networks.py ADDED
@@ -0,0 +1,1375 @@
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import numpy as np
import torch.nn.functional as F
from torch.nn.modules.normalization import LayerNorm
import os
from torch.nn.utils import spectral_norm
from torchvision import models

###############################################################################
# Helper functions
###############################################################################


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    We use 'normal' in the original pix2pix and CycleGAN paper, but xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
    if init:
        init_weights(net, init_type, init_gain=init_gain)
    return net
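
# Illustrative sketch (not part of the original file): init_net/init_weights
# re-initialize every Conv/Linear module in place via net.apply, so a fresh
# model can be set up in one call. The throwaway module below is hypothetical.
def _demo_init_net():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    net = init_net(net, init_type='kaiming', init_gain=0.02, gpu_ids=[])
    # Conv weights are now kaiming-initialized; BatchNorm weights are drawn
    # from N(1.0, 0.02) and biases zeroed, as in init_func above.
    return net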


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler
    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For the other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
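
# Illustrative sketch (not part of the original file): under the 'linear'
# policy the LR multiplier stays at 1.0 while epoch + epoch_count <= niter,
# then decays linearly to zero. With niter=100, niter_decay=100, epoch_count=1:
# epochs 1-99 -> 1.0, epoch 100 -> ~0.990, epoch 150 -> ~0.495, epoch 200 -> 0.0.
def _demo_linear_rule(epoch, niter=100, niter_decay=100, epoch_count=1):
    return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)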


class LayerNormWarpper(nn.Module):
    def __init__(self, num_features):
        super(LayerNormWarpper, self).__init__()
        self.num_features = int(num_features)

    def forward(self, x):
        # builds an affine-free LayerNorm over (C, H, W) on the fly each call,
        # so the normalized shape follows the input's spatial size
        x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).cuda()(x)
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | layer | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'layer':
        norm_layer = functools.partial(LayerNormWarpper)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
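
# Illustrative sketch (not part of the original file): get_norm_layer returns
# a constructor (a functools.partial), not a layer; callers instantiate it
# with a channel count. Note that the 'layer' option wraps LayerNormWarpper,
# which creates a fresh affine-free nn.LayerNorm on every forward pass, so it
# carries no learnable parameters of its own.
def _demo_norm_layers():
    batch_norm = get_norm_layer('batch')(64)     # nn.BatchNorm2d(64, affine=True, ...)
    inst_norm = get_norm_layer('instance')(64)   # nn.InstanceNorm2d(64, affine=False, ...)
    layer_norm = get_norm_layer('layer')(64)     # LayerNormWarpper(64)
    return batch_norm, inst_norm, layer_norm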
| 123 | 
         
            +
def get_non_linearity(layer_type='relu'):
    """Return a factory for the requested activation module."""
    if layer_type == 'relu':
        nl_layer = functools.partial(nn.ReLU, inplace=True)
    elif layer_type == 'lrelu':
        nl_layer = functools.partial(
            nn.LeakyReLU, negative_slope=0.2, inplace=True)
    elif layer_type == 'elu':
        nl_layer = functools.partial(nn.ELU, inplace=True)
    elif layer_type == 'selu':
        nl_layer = functools.partial(nn.SELU, inplace=True)
    elif layer_type == 'prelu':
        nl_layer = functools.partial(nn.PReLU)
    else:
        raise NotImplementedError(
            'nonlinearity activation [%s] is not found' % layer_type)
    return nl_layer

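# --- Usage sketch (illustrative only). nl_layer is likewise a factory: the
# network builders below call nl_layer() every time they need a fresh
# activation module instance.
def _demo_get_non_linearity():
    import torch
    nl_layer = get_non_linearity('lrelu')
    act = nl_layer()                         # nn.LeakyReLU(0.2, inplace=True)
    return act(torch.randn(4))
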
def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', use_noise=False,
             use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl_layer = get_non_linearity(layer_type=nl)

    if nz == 0:  # without a latent code there is nothing to inject at every layer
        where_add = 'input'

    if netG == 'unet_128' and where_add == 'input':
        net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                               use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_128_G' and where_add == 'input':
        net = G_Unet_add_input_G(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                                 use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_256' and where_add == 'input':
        net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                               use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_256_G' and where_add == 'input':
        net = G_Unet_add_input_G(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                                 use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_128' and where_add == 'all':
        net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                             use_dropout=use_dropout, upsample=upsample)
    elif netG == 'unet_256' and where_add == 'all':
        net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                             use_dropout=use_dropout, upsample=upsample)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)

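# --- Usage sketch (illustrative; the argument values are assumptions, not
# this project's training configuration). Builds a 256x256 U-Net generator
# whose latent code z is concatenated to the input ('where_add'='input').
# init_net and UnetBlock_A are defined elsewhere in this file; with
# gpu_ids=[] the net presumably stays on CPU.
def _demo_define_G():
    import torch
    G = define_G(input_nc=3, output_nc=3, nz=8, ngf=64,
                 netG='unet_256', norm='instance', gpu_ids=[])
    x = torch.randn(1, 3, 256, 256)
    z = torch.randn(1, 8)
    return G(x, z)  # expected (1, 3, 256, 256), tanh-bounded
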
def define_C(input_nc, output_nc, nz, ngf, netC='unet_128', norm='instance', nl='relu',
             use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], upsample='basic'):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl_layer = get_non_linearity(layer_type=nl)

    if netC == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netC == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netC == 'unet_128':
        net = G_Unet_add_input_C(input_nc, output_nc, 0, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
                                 use_dropout=use_dropout, upsample=upsample)
    elif netC == 'unet_256':
        net = G_Unet_add_input(input_nc, output_nc, 0, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
                               use_dropout=use_dropout, upsample=upsample)
    elif netC == 'unet_32':
        net = G_Unet_add_input(input_nc, output_nc, 0, 5, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
                               use_dropout=use_dropout, upsample=upsample)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netC)

    return init_net(net, init_type, init_gain, gpu_ids)

def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl = 'lrelu'  # always use leaky relu for D, regardless of the argument
    nl_layer = get_non_linearity(layer_type=nl)

    if netD == 'basic_128':
        net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
    elif netD == 'basic_256':
        net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
    elif netD == 'basic_128_multi':
        net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
    elif netD == 'basic_256_multi':
        net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)

def define_E(input_nc, output_nc, ndf, netE, norm='batch', nl='lrelu',
             init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl = 'lrelu'  # always use leaky relu for E, regardless of the argument
    nl_layer = get_non_linearity(layer_type=nl)
    if netE == 'resnet_128':
        net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
                       nl_layer=nl_layer, vaeLike=vaeLike)
    elif netE == 'resnet_256':
        net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
                       nl_layer=nl_layer, vaeLike=vaeLike)
    elif netE == 'conv_128':
        net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
                        nl_layer=nl_layer, vaeLike=vaeLike)
    elif netE == 'conv_256':
        net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
                        nl_layer=nl_layer, vaeLike=vaeLike)
    else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)

    return init_net(net, init_type, init_gain, gpu_ids, False)

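# --- Usage sketch (illustrative). Assuming that, as in BicycleGAN-style
# encoders, vaeLike=True makes E_ResNet return a (mu, logvar) pair, the
# latent code is drawn with the reparameterization trick:
#     z = mu + exp(0.5 * logvar) * eps,   eps ~ N(0, I)
def _demo_define_E():
    import torch
    E = define_E(input_nc=3, output_nc=8, ndf=64, netE='resnet_256',
                 gpu_ids=[], vaeLike=True)
    mu, logvar = E(torch.randn(1, 3, 256, 256))
    eps = torch.randn_like(mu)
    z = mu + torch.exp(0.5 * logvar) * eps   # stochastic latent code
    return z
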
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, norm_layer=None, use_dropout=False, n_blocks=6, padding_type='replicate'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        model = [nn.ReplicationPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias)]
        if norm_layer is not None:
            model += [norm_layer(ngf)]
        model += [nn.ReLU(True)]

        for i in range(n_downsampling):  # stride-2 downsampling with replication padding
            mult = 2**i
            model += [nn.ReplicationPad2d(1),
                      nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=0, bias=use_bias)]
            if norm_layer is not None:
                model += [norm_layer(ngf * mult * 2)]
            model += [nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):  # residual blocks at the bottleneck resolution
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # bilinear upsampling (a ConvTranspose2d variant was replaced here)
            mult = 2**(n_downsampling - i)
            model += upsampleLayer(ngf * mult, int(ngf * mult / 2), upsample='bilinear', padding_type=padding_type)
            if norm_layer is not None:
                model += [norm_layer(int(ngf * mult / 2))]
            model += [nn.ReLU(True)]
            model += [nn.ReplicationPad2d(1),
                      nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2), kernel_size=3, padding=0)]
            if norm_layer is not None:
                model += [norm_layer(int(ngf * mult / 2))]
            model += [nn.ReLU(True)]
        model += [nn.ReplicationPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        # note: no final Tanh -- the raw convolution output is returned

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

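# --- Shape sketch (illustrative). The generator is fully convolutional and,
# because every downsampling step is mirrored by an upsampling step, it
# preserves the input resolution (inputs should be divisible by
# 2**n_downsampling, i.e. by 8 with the defaults).
def _demo_resnet_generator():
    import torch
    net = ResnetGenerator(input_nc=3, output_nc=1, ngf=64,
                          norm_layer=get_norm_layer('instance'))
    y = net(torch.randn(1, 3, 128, 128))
    assert y.shape == (1, 1, 128, 128)
    return y
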
# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        if norm_layer is not None:
            conv_block += [norm_layer(dim)]
        conv_block += [nn.ReLU(True)]
        # dropout is intentionally disabled here, even when use_dropout is set

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        if norm_layer is not None:
            conv_block += [norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)  # residual connection
        return out

class D_NLayersMulti(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3,
                 norm_layer=nn.BatchNorm2d, num_D=1, nl_layer=None):
        super(D_NLayersMulti, self).__init__()
        self.num_D = num_D
        self.nl_layer = nl_layer
        if num_D == 1:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.model = nn.Sequential(*layers)
        else:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.add_module("model_0", nn.Sequential(*layers))
            self.down = nn.functional.interpolate
            for i in range(1, num_D):
                ndf_i = int(round(ndf / (2**i)))  # narrower nets at coarser scales
                layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
                self.add_module("model_%d" % i, nn.Sequential(*layers))

    def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        kw = 3
        padw = 1
        sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
                                  stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                                       kernel_size=kw, stride=2, padding=padw))]
            if norm_layer:
                sequence += [norm_layer(ndf * nf_mult)]

            sequence += [self.nl_layer()]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                                   kernel_size=kw, stride=1, padding=padw))]
        if norm_layer:
            sequence += [norm_layer(ndf * nf_mult)]
        sequence += [self.nl_layer()]

        sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1,
                                   kernel_size=kw, stride=1, padding=padw))]

        return sequence

    def forward(self, input):
        if self.num_D == 1:
            return self.model(input)
        result = []
        down = input
        for i in range(self.num_D):
            model = getattr(self, "model_%d" % i)
            result.append(model(down))
            if i != self.num_D - 1:  # halve the resolution for the next scale
                down = self.down(down, scale_factor=0.5, mode='bilinear')
        return result

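# --- Usage sketch (illustrative). With num_D > 1 the forward pass returns
# one patch-prediction map per scale, each computed on a 2x-downsampled copy
# of the input, so the GAN loss can be summed over scales.
def _demo_multiscale_D():
    import torch
    D = D_NLayersMulti(input_nc=3, ndf=64, n_layers=3,
                       norm_layer=get_norm_layer('instance'), num_D=2,
                       nl_layer=get_non_linearity('lrelu'))
    outs = D(torch.randn(1, 3, 256, 256))
    return [o.shape for o in outs]  # two maps, the second at half resolution
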
class D_NLayers(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, nl_layer=None):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
            nl_layer        -- accepted for interface compatibility with define_D; LeakyReLU is used internally
        """
        super(D_NLayers, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 3
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)

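# --- Shape sketch (illustrative). D_NLayers is a PatchGAN: the output is not
# a single logit but a map of per-patch predictions. With n_layers=3 and 3x3
# kernels, a 256x256 input yields a 1x1x32x32 map (three stride-2 convs
# followed by two stride-1 convs).
def _demo_patchgan_D():
    import torch
    D = D_NLayers(input_nc=3, ndf=64, n_layers=3)
    out = D(torch.randn(1, 3, 256, 256))
    assert out.shape == (1, 1, 32, 32)
    return out
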
class G_Unet_add_input(nn.Module):
    def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
                 norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
                 upsample='basic', device=0):
        super(G_Unet_add_input, self).__init__()
        self.nz = nz
        max_nchn = 8
        # one noise flag per resolution level
        noise = [use_noise for _ in range(num_downs + 1)]

        # construct unet structure, from the innermost block outwards
        unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=noise[num_downs-1],
                                 innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        for i in range(num_downs - 5):
            unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise[num_downs-i-3],
                                     norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        unet_block = UnetBlock_A(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_A(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_A(ngf, ngf, ngf * 2, unet_block, noise[0],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_A(input_nc + nz, output_nc, ngf, unet_block, None,
                                 outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)

        self.model = unet_block

    def forward(self, x, z=None):
        if self.nz > 0:
            # broadcast the latent code over the spatial grid and concatenate
            z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
                z.size(0), z.size(1), x.size(2), x.size(3))
            x_with_z = torch.cat([x, z_img], 1)
        else:
            x_with_z = x  # no z

        return torch.tanh(self.model(x_with_z))

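# --- Sketch of the latent-injection step above (illustrative). A 2-D code z
# of shape (N, nz) is broadcast to an (N, nz, H, W) feature map and
# concatenated with the image along the channel axis.
def _demo_z_broadcast():
    import torch
    x = torch.randn(2, 1, 64, 64)                    # image batch
    z = torch.randn(2, 8)                            # latent codes
    z_img = z.view(2, 8, 1, 1).expand(2, 8, 64, 64)  # no data copied, just broadcast
    return torch.cat([x, z_img], 1)                  # (2, 9, 64, 64)
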
class G_Unet_add_input_G(nn.Module):
    def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
                 norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
                 upsample='basic', device=0):
        super(G_Unet_add_input_G, self).__init__()
        self.nz = nz
        max_nchn = 8
        noise = [use_noise for _ in range(num_downs + 1)]
        # construct unet structure; the outer levels always use the 'basic'
        # (ConvTranspose2d) upsampling, the inner ones the requested mode
        unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
                                 innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        for i in range(num_downs - 5):
            unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
                                     norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        unet_block = UnetBlock_G(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
        unet_block = UnetBlock_G(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
        unet_block = UnetBlock_G(ngf, ngf, ngf * 2, unet_block, noise[0],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
        unet_block = UnetBlock_G(input_nc + nz, output_nc, ngf, unet_block, None,
                                 outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')

        self.model = unet_block

    def forward(self, x, z=None):
        if self.nz > 0:
            z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
                z.size(0), z.size(1), x.size(2), x.size(3))
            x_with_z = torch.cat([x, z_img], 1)
        else:
            x_with_z = x  # no z

        return self.model(x_with_z)  # no tanh: raw output

class G_Unet_add_input_C(nn.Module):
    def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
                 norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
                 upsample='basic', device=0):
        super(G_Unet_add_input_C, self).__init__()
        self.nz = nz
        max_nchn = 8
        # construct unet structure
        unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
                               innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        for i in range(num_downs - 5):
            unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
                                   norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise=False,
                               norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, noise=False,
                               norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, noise=False,
                               norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, noise=False,
                               outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)

        self.model = unet_block

    def forward(self, x, z=None):
        if self.nz > 0:
            z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
                z.size(0), z.size(1), x.size(2), x.size(3))
            x_with_z = torch.cat([x, z_img], 1)
        else:
            x_with_z = x  # no z

        # return torch.tanh(self.model(x_with_z))
        return self.model(x_with_z)

def upsampleLayer(inplanes, outplanes, kw=1, upsample='basic', padding_type='replicate'):
    # 'basic' uses a learned ConvTranspose2d; the other modes use a fixed
    # interpolation followed by a 1x1 convolution. kw and padding_type are
    # kept for interface compatibility.
    if upsample == 'basic':
        upconv = [nn.ConvTranspose2d(inplanes, outplanes, kernel_size=4, stride=2, padding=1)]
    elif upsample in ('bilinear', 'nearest', 'linear'):
        align = None if upsample == 'nearest' else True  # align_corners is invalid for 'nearest'
        upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=align),
                  nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)]
    else:
        raise NotImplementedError(
            'upsample layer [%s] not implemented' % upsample)
    return upconv

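# --- Shape sketch (illustrative). Both branches double the spatial size:
# 'basic' with a learned ConvTranspose2d, 'bilinear' with a fixed
# interpolation plus a 1x1 conv (which tends to avoid checkerboard artifacts).
def _demo_upsample_layer():
    import torch
    for mode in ('basic', 'bilinear'):
        up = nn.Sequential(*upsampleLayer(32, 16, upsample=mode))
        y = up(torch.randn(1, 32, 8, 8))
        assert y.shape == (1, 16, 16, 16)
    return y
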
class UnetBlock_G(nn.Module):
    def __init__(self, input_nc, outer_nc, inner_nc,
                 submodule=None, noise=None, outermost=False, innermost=False,
                 norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
        super(UnetBlock_G, self).__init__()
        self.outermost = outermost
        p = 0
        downconv = []
        if padding_type == 'reflect':
            downconv += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            downconv += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        downconv += [nn.Conv2d(input_nc, inner_nc,
                               kernel_size=3, stride=2, padding=p)]
        # downsample is different from upsample
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc) if norm_layer is not None else None
        uprelu = nl_layer()
        uprelu2 = nl_layer()
        uppad = nn.ReplicationPad2d(1)  # pairs with the padding=p convs below; intended for the default 'replicate'
        upnorm = norm_layer(outer_nc) if norm_layer is not None else None
        upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
        self.noiseblock = ApplyNoise(outer_nc)
        self.noise = noise

        if outermost:
            upconv = upsampleLayer(inner_nc * 2, inner_nc, upsample=upsample, padding_type=padding_type)
            uppad = nn.ReplicationPad2d(3)
            upconv2 = nn.Conv2d(inner_nc, outer_nc, kernel_size=7, padding=0)
            down = downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [norm_layer(inner_nc)]  # upconv keeps inner_nc channels at this level
            up += [uprelu2, uppad, upconv2]   # no final Tanh
            model = down + [submodule] + up
        elif innermost:
            upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]
            model = down + up
        else:
            upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            if downnorm is not None:
                down += [downnorm]
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            x2 = self.model(x)
            if self.noise:
                x2 = self.noiseblock(x2, self.noise)
            return torch.cat([x2, x], 1)  # U-Net skip connection

class UnetBlock(nn.Module):
    def __init__(self, input_nc, outer_nc, inner_nc,
                 submodule=None, noise=None, outermost=False, innermost=False,
                 norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
        super(UnetBlock, self).__init__()
        self.outermost = outermost
        p = 0
        downconv = []
        if padding_type == 'reflect':
            downconv += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            downconv += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        downconv += [nn.Conv2d(input_nc, inner_nc,
                               kernel_size=3, stride=2, padding=p)]
        # downsample is different from upsample
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc) if norm_layer is not None else None
        uprelu = nl_layer()
        uprelu2 = nl_layer()
        uppad = nn.ReplicationPad2d(1)
        upnorm = norm_layer(outer_nc) if norm_layer is not None else None
        upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
        self.noiseblock = ApplyNoise(outer_nc)
        self.noise = noise

        if outermost:
            upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]  # + [nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]
            model = down + up
        else:
            upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            if downnorm is not None:
                down += [downnorm]
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            x2 = self.model(x)
            if self.noise:
                x2 = self.noiseblock(x2, self.noise)
            return torch.cat([x2, x], 1)

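
# Illustrative usage sketch, not part of the original network definitions.
# It assumes the 'basic' mode of upsampleLayer (defined earlier in this file)
# doubles the spatial resolution, mirroring the stride-2 downconv; all sizes
# below are arbitrary.
def _demo_unet_block():
    import functools
    blk = UnetBlock(input_nc=64, outer_nc=64, inner_nc=128, innermost=True,
                    norm_layer=functools.partial(nn.InstanceNorm2d, affine=False),
                    nl_layer=functools.partial(nn.ReLU, inplace=True))
    x = torch.randn(2, 64, 32, 32)
    y = blk(x)  # non-outermost blocks concatenate up(down(x)) with x
    assert y.shape == (2, 64 + 64, 32, 32)
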
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetBlock_A(nn.Module):
    def __init__(self, input_nc, outer_nc, inner_nc,
                 submodule=None, noise=None, outermost=False, innermost=False,
                 norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
        super(UnetBlock_A, self).__init__()
        self.outermost = outermost
        p = 0
        downconv = []
        if padding_type == 'reflect':
            downconv += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            downconv += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
                                             kernel_size=3, stride=2, padding=p))]
        # downsample is different from upsample
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc) if norm_layer is not None else None
        uprelu = nl_layer()
        uprelu2 = nl_layer()
        uppad = nn.ReplicationPad2d(1)
        upnorm = norm_layer(outer_nc) if norm_layer is not None else None
        upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
        self.noiseblock = ApplyNoise(outer_nc)
        self.noise = noise

        if outermost:
            upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
            down = downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]  # + [nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
            down = [downrelu] + downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]
            model = down + up
        else:
            upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
            down = [downrelu] + downconv
            if downnorm is not None:
                down += [downnorm]
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            x2 = self.model(x)
            if self.noise:
                x2 = self.noiseblock(x2, self.noise)
            if x2.shape[-1] == x.shape[-1]:
                return x2 + x
            else:
                x2 = F.interpolate(x2, x.shape[2:])
                return x2 + x

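
# Illustrative usage sketch (same upsampleLayer assumption as the UnetBlock
# demo above): UnetBlock_A fuses features by residual addition (x2 + x)
# instead of channel concatenation, so outer_nc must match input_nc, and the
# forward pass interpolates x2 when odd input sizes make the shapes drift.
def _demo_unet_block_a():
    import functools
    blk = UnetBlock_A(input_nc=64, outer_nc=64, inner_nc=128, innermost=True,
                      norm_layer=functools.partial(nn.InstanceNorm2d, affine=False),
                      nl_layer=functools.partial(nn.ReLU, inplace=True))
    x = torch.randn(2, 64, 33, 33)  # odd size exercises the interpolate branch
    assert blk(x).shape == x.shape
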
class E_ResNet(nn.Module):
    def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
                 norm_layer=None, nl_layer=None, vaeLike=False):
        super(E_ResNet, self).__init__()
        self.vaeLike = vaeLike
        max_ndf = 4
        conv_layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=3, stride=2, padding=1, bias=True)]
        for n in range(1, n_blocks):
            input_ndf = ndf * min(max_ndf, n)
            output_ndf = ndf * min(max_ndf, n + 1)
            conv_layers += [BasicBlock(input_ndf,
                                       output_ndf, norm_layer, nl_layer)]
        conv_layers += [nl_layer(), nn.AdaptiveAvgPool2d(4)]
        if vaeLike:
            self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
            self.fcVar = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
        else:
            self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
        self.conv = nn.Sequential(*conv_layers)

    def forward(self, x):
        x_conv = self.conv(x)
        conv_flat = x_conv.view(x.size(0), -1)
        output = self.fc(conv_flat)
        if self.vaeLike:
            outputVar = self.fcVar(conv_flat)
            return output, outputVar
        else:
            return output

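
# Illustrative usage sketch: with vaeLike=True the encoder returns a
# (mu, logvar) pair suitable for reparameterization; BasicBlock is defined
# earlier in this file. Parameter values here are arbitrary.
def _demo_e_resnet():
    import functools
    enc = E_ResNet(input_nc=3, output_nc=8, ndf=64, n_blocks=4,
                   norm_layer=functools.partial(nn.InstanceNorm2d, affine=False),
                   nl_layer=functools.partial(nn.ReLU, inplace=True),
                   vaeLike=True)
    mu, logvar = enc(torch.randn(2, 3, 128, 128))
    assert mu.shape == logvar.shape == (2, 8)
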
# Defines the Unet generator.
# |num_downs|: number of downsamplings in the UNet. For example,
# if |num_downs| == 7, an image of size 128x128 becomes 1x1
# at the bottleneck.
class G_Unet_add_all(nn.Module):
    def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
                 norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, upsample='basic'):
        super(G_Unet_add_all, self).__init__()
        self.nz = nz
        self.mapping = G_mapping(self.nz, self.nz, 512, normalize_latents=False, lrmul=1)
        self.truncation_psi = 0
        self.truncation_cutoff = 0

        # The "- 2" reflects that synthesis starts from a 4x4 feature map;
        # for resolution 512 this gives num_layers = 16.
        num_layers = int(np.log2(512)) * 2 - 2
        # Noise inputs.
        self.noise_inputs = []
        for layer_idx in range(num_layers):
            res = layer_idx // 2 + 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noise_inputs.append(torch.randn(*shape).to("cuda"))

        # construct unet structure
        unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
                                      norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
                                      norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        for i in range(num_downs - 6):
            unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
                                          norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, submodule=unet_block,
                                      norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, submodule=unet_block,
                                      norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_with_z(ngf, ngf, ngf * 2, nz, submodule=unet_block,
                                      norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, submodule=unet_block,
                                      outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        self.model = unet_block

    def forward(self, x, z):
        dlatents1, num_layers = self.mapping(z)
        dlatents1 = dlatents1.unsqueeze(1)
        dlatents1 = dlatents1.expand(-1, int(num_layers), -1)

        # Apply truncation trick.
        if self.truncation_psi and self.truncation_cutoff:
            coefs = np.ones([1, num_layers, 1], dtype=np.float32)
            for i in range(num_layers):
                if i < self.truncation_cutoff:
                    coefs[:, i, :] *= self.truncation_psi
            """Linear interpolation.
               a + (b - a) * t (a = 0)
               reduces to
               b * t
            """
            dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device)

        return torch.tanh(self.model(x, dlatents1, self.noise_inputs))

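
# Illustrative usage sketch: a seven-level generator over 128x128 inputs with
# an 8-dim latent. The noise_inputs above are allocated on "cuda", so this
# sketch assumes a GPU; channel and latent sizes are arbitrary, and the
# 'basic' upsampleLayer is assumed to double spatial resolution.
def _demo_g_unet_add_all():
    import functools
    G = G_Unet_add_all(input_nc=1, output_nc=3, nz=8, num_downs=7, ngf=64,
                       norm_layer=functools.partial(nn.InstanceNorm2d, affine=False),
                       nl_layer=functools.partial(nn.ReLU, inplace=True)).to("cuda")
    x = torch.randn(2, 1, 128, 128, device="cuda")
    z = torch.randn(2, 8, device="cuda")
    y = G(x, z)  # tanh-squashed output at the input resolution
    assert y.shape == (2, 3, 128, 128)
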
class ApplyNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.weight = nn.Parameter(torch.randn(channels), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)

    def forward(self, x, noise):
        W, _ = torch.split(self.weight.view(1, -1, 1, 1), self.channels // 2, dim=1)
        B, _ = torch.split(self.bias.view(1, -1, 1, 1), self.channels // 2, dim=1)
        Z = torch.zeros_like(W)
        w = torch.cat([W, Z], dim=1).to(x.device)
        b = torch.cat([B, Z], dim=1).to(x.device)
        adds = w * torch.randn_like(x) + b
        return x + adds.type_as(x)

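
# Illustrative sketch: ApplyNoise learns per-channel noise strength, but the
# zero-padding of w and b above means only the first half of the channels
# receives noise; the `noise` argument itself is ignored and fresh Gaussian
# noise is drawn with torch.randn_like on every call.
def _demo_apply_noise():
    an = ApplyNoise(channels=8)
    out = an(torch.zeros(1, 8, 4, 4), noise=None)
    assert out[:, 4:].abs().sum() == 0  # second half untouched (bias starts at 0)
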
class FC(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 gain=2**(0.5),
                 use_wscale=False,
                 lrmul=1.0,
                 bias=True):
        """
            A complete conversion of the Dense/FC/Linear layer from the original TensorFlow version.
        """
        super(FC, self).__init__()
        he_std = gain * in_channels ** (-0.5)  # He init
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_lrmul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_lrmul = lrmul

        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(out_channels))
            self.b_lrmul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        if self.bias is not None:
            out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)
        else:
            out = F.linear(x, self.weight * self.w_lrmul)
        out = F.leaky_relu(out, 0.2, inplace=True)
        return out

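
# Illustrative usage sketch: FC is the equalized-learning-rate linear layer
# ported from the TensorFlow StyleGAN code; with use_wscale=True the He-init
# scale is applied to the weight at runtime rather than baked in at init.
# (G_mapping below keeps FC only in commented-out form and uses plain
# nn.Linear instead.)
def _demo_fc():
    fc = FC(in_channels=16, out_channels=32, use_wscale=True, lrmul=0.01)
    assert fc(torch.randn(4, 16)).shape == (4, 32)
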
class ApplyStyle(nn.Module):
    """
        @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """
    def __init__(self, latent_size, channels, use_wscale, nl_layer):
        super(ApplyStyle, self).__init__()
        modules = [nn.Linear(latent_size, channels*2)]
        if nl_layer:
            modules += [nl_layer()]
        self.linear = nn.Sequential(*modules)

    def forward(self, x, latent):
        style = self.linear(latent)  # style => [batch_size, n_channels*2]
        shape = [-1, 2, x.size(1), 1, 1]
        style = style.view(shape)    # [batch_size, 2, n_channels, 1, 1]
        x = x * (style[:, 0] + 1.) + style[:, 1]
        return x

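
# Illustrative usage sketch: ApplyStyle predicts a per-channel (scale, shift)
# pair from the latent and applies x * (scale + 1) + shift, i.e. the
# AdaIN-style modulation used after instance normalization in LayerEpilogue
# below.
def _demo_apply_style():
    st = ApplyStyle(latent_size=8, channels=16, use_wscale=True, nl_layer=None)
    x = torch.randn(2, 16, 8, 8)
    assert st(x, torch.randn(2, 8)).shape == x.shape
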
class PixelNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
            @notice: avoid in-place ops.
            https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(PixelNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        tmp = torch.mul(x, x)  # or x ** 2
        tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)
        return x * tmp1

class InstanceNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
            @notice: avoid in-place ops.
            https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(InstanceNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        x = x - torch.mean(x, (2, 3), True)
        tmp = torch.mul(x, x)  # or x ** 2
        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
        return x * tmp

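
# Illustrative contrast between the two norms above: PixelNorm rescales each
# spatial location to unit RMS across channels (used on the latent path),
# while InstanceNorm whitens each channel over its own spatial extent (used
# on feature maps before style modulation).
def _demo_pixel_norm():
    x = torch.randn(2, 16, 8, 8)
    rms = PixelNorm()(x).pow(2).mean(dim=1).sqrt()  # per-location RMS over channels
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
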
class LayerEpilogue(nn.Module):
    def __init__(self, channels, dlatent_size, use_wscale, use_noise,
                 use_pixel_norm, use_instance_norm, use_styles, nl_layer=None):
        super(LayerEpilogue, self).__init__()
        self.use_noise = use_noise
        if use_noise:
            self.noise = ApplyNoise(channels)
        self.act = nn.LeakyReLU(negative_slope=0.2)

        if use_pixel_norm:
            self.pixel_norm = PixelNorm()
        else:
            self.pixel_norm = None

        if use_instance_norm:
            self.instance_norm = InstanceNorm()
        else:
            self.instance_norm = None

        if use_styles:
            self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale, nl_layer=nl_layer)
        else:
            self.style_mod = None

    def forward(self, x, noise, dlatents_in_slice=None):
        # if noise is not None:
        if self.use_noise:
            x = self.noise(x, noise)
        x = self.act(x)
        if self.pixel_norm is not None:
            x = self.pixel_norm(x)
        if self.instance_norm is not None:
            x = self.instance_norm(x)
        if self.style_mod is not None:
            x = self.style_mod(x, dlatents_in_slice)

        return x

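
# Illustrative usage sketch: the epilogue applied after each synthesis step --
# optional noise injection, activation, pixel/instance normalization, then
# style modulation from one dlatent slice. Sizes here are arbitrary.
def _demo_layer_epilogue():
    import functools
    ep = LayerEpilogue(channels=32, dlatent_size=8, use_wscale=True,
                       use_noise=True, use_pixel_norm=False,
                       use_instance_norm=True, use_styles=True,
                       nl_layer=functools.partial(nn.ReLU, inplace=True))
    x = torch.randn(2, 32, 16, 16)
    assert ep(x, noise=None, dlatents_in_slice=torch.randn(2, 8)).shape == x.shape
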
class G_mapping(nn.Module):
    def __init__(self,
                 mapping_fmaps=512,
                 dlatent_size=512,
                 resolution=512,
                 normalize_latents=True,  # Normalize latent vectors (Z) before feeding them to the mapping layers?
                 use_wscale=True,         # Enable equalized learning rate?
                 lrmul=0.01,              # Learning rate multiplier for the mapping layers.
                 gain=2**(0.5),           # original gain in TensorFlow.
                 nl_layer=None
                 ):
        super(G_mapping, self).__init__()
        self.mapping_fmaps = mapping_fmaps
        func = [
            nn.Linear(self.mapping_fmaps, dlatent_size)
        ]
        if nl_layer:
            func += [nl_layer()]

        for j in range(0, 4):
            func += [
                nn.Linear(dlatent_size, dlatent_size)
            ]
            if nl_layer:
                func += [nl_layer()]

        self.func = nn.Sequential(*func)
        # FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
        # FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),

        self.normalize_latents = normalize_latents
        self.resolution_log2 = int(np.log2(resolution))
        self.num_layers = self.resolution_log2 * 2 - 2
        self.pixel_norm = PixelNorm()
        # The "- 2" reflects that synthesis starts from a 4x4 feature map;
        # for resolution 512 this gives num_layers = 16.

    def forward(self, x):
        if self.normalize_latents:
            x = self.pixel_norm(x)
        out = self.func(x)
        return out, self.num_layers

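
# Illustrative usage sketch: a five-layer MLP from z to w; the second return
# value is the number of per-layer w slots the synthesis path expects
# (2 * log2(resolution) - 2, i.e. 16 for the default resolution of 512).
def _demo_g_mapping():
    m = G_mapping(mapping_fmaps=8, dlatent_size=8, resolution=512,
                  normalize_latents=False)
    w, num_layers = m(torch.randn(2, 8))
    assert w.shape == (2, 8) and num_layers == 16
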
class UnetBlock_with_z(nn.Module):
    def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=None, nl_layer=None, use_dropout=False,
                 upsample='basic', padding_type='replicate'):
        super(UnetBlock_with_z, self).__init__()
        p = 0
        downconv = []
        if padding_type == 'reflect':
            downconv += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            downconv += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        self.outermost = outermost
        self.innermost = innermost
        self.nz = nz

        # input_nc = input_nc + nz
        downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
                                             kernel_size=3, stride=2, padding=p))]
        # downsample is different from upsample
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc) if norm_layer is not None else None
        uprelu = nl_layer()
        uprelu2 = nl_layer()
        uppad = nn.ReplicationPad2d(1)
        upnorm = norm_layer(outer_nc) if norm_layer is not None else None
        upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None

        use_styles = False
        uprelu = nl_layer()
        if self.nz > 0:
            use_styles = True

        if outermost:
            self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
                                       use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
            upconv = upsampleLayer(
                inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
            down = downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]  # + [nn.Tanh()]
        elif innermost:
            self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=True,
                                       use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
            upconv = upsampleLayer(
                inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
            down = [downrelu] + downconv
            up = [uprelu] + upconv
            if norm_layer is not None:
                up += [norm_layer(outer_nc)]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]
        else:
            self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
                                       use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
            upconv = upsampleLayer(
                inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
            down = [downrelu] + downconv
            if norm_layer is not None:
                down += [norm_layer(inner_nc)]
            up = [uprelu] + upconv

            if norm_layer is not None:
                up += [norm_layer(outer_nc)]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]

            if use_dropout:
                up += [nn.Dropout(0.5)]
        self.down = nn.Sequential(*down)
        self.submodule = submodule
        self.up = nn.Sequential(*up)

    def forward(self, x, z, noise):
        if self.outermost:
            x1 = self.down(x)
            x2 = self.submodule(x1, z[:, 2:], noise[2:])
            return self.up(x2)

        elif self.innermost:
            x1 = self.down(x)
            x_and_z = self.adaIn(x1, noise[0], z[:, 0])
            x2 = self.up(x_and_z)
            x2 = F.interpolate(x2, x.shape[2:])
            return x2 + x

        else:
            x1 = self.down(x)
            x2 = self.submodule(x1, z[:, 2:], noise[2:])
            x_and_z = self.adaIn(x2, noise[0], z[:, 0])
            return self.up(x_and_z) + x

            class E_NLayers(nn.Module):
         
     | 
| 1241 | 
         
            +
                def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=4,
         
     | 
| 1242 | 
         
            +
                             norm_layer=None, nl_layer=None, vaeLike=False):
         
     | 
| 1243 | 
         
            +
                    super(E_NLayers, self).__init__()
         
     | 
| 1244 | 
         
            +
                    self.vaeLike = vaeLike
         
     | 
| 1245 | 
         
            +
             
     | 
| 1246 | 
         
            +
                    kw, padw = 3, 1
         
     | 
| 1247 | 
         
            +
                    sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
         
     | 
| 1248 | 
         
            +
                                          stride=2, padding=padw, padding_mode='replicate')), nl_layer()]
         
     | 
| 1249 | 
         
            +
             
     | 
| 1250 | 
         
            +
                    nf_mult = 1
         
     | 
| 1251 | 
         
            +
                    nf_mult_prev = 1
         
     | 
| 1252 | 
         
            +
                    for n in range(1, n_layers):
         
     | 
| 1253 | 
         
            +
                        nf_mult_prev = nf_mult
         
     | 
| 1254 | 
         
            +
                        nf_mult = min(2**n, 8)
         
     | 
| 1255 | 
         
            +
                        sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
         
     | 
| 1256 | 
         
            +
                                     kernel_size=kw, stride=2, padding=padw, padding_mode='replicate'))]
         
     | 
| 1257 | 
         
            +
                        if norm_layer is not None:
         
     | 
| 1258 | 
         
            +
                            sequence += [norm_layer(ndf * nf_mult)]
         
     | 
| 1259 | 
         
            +
                        sequence += [nl_layer()]
         
     | 
| 1260 | 
         
            +
                    sequence += [nn.AdaptiveAvgPool2d(4)]
         
     | 
| 1261 | 
         
            +
                    self.conv = nn.Sequential(*sequence)
         
     | 
| 1262 | 
         
            +
                    self.fc = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
         
     | 
| 1263 | 
         
            +
                    if vaeLike:
         
     | 
| 1264 | 
         
            +
                        self.fcVar = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
         
     | 
| 1265 | 
         
            +
             
     | 
| 1266 | 
         
            +
                def forward(self, x):
         
     | 
| 1267 | 
         
            +
                    x_conv = self.conv(x)
         
     | 
| 1268 | 
         
            +
                    conv_flat = x_conv.view(x.size(0), -1)
         
     | 
| 1269 | 
         
            +
                    output = self.fc(conv_flat)
         
     | 
| 1270 | 
         
            +
                    if self.vaeLike:
         
     | 
| 1271 | 
         
            +
                        outputVar = self.fcVar(conv_flat)
         
     | 
| 1272 | 
         
            +
                        return output, outputVar
         
     | 
| 1273 | 
         
            +
                    return output
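
# Sketch added for illustration (not part of the original file): with
# vaeLike=True the encoder above returns a (mu, logvar) pair rather than a
# point estimate; a latent code is usually drawn from it with the
# reparameterization trick so that the sampling step stays differentiable.
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)  # logvar -> standard deviation
    eps = torch.randn_like(std)    # eps ~ N(0, I)
    return mu + eps * std          # z ~ N(mu, std^2)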
         

class BasicBlock(nn.Module):
    def __init__(self, inplanes, outplanes):
        super(BasicBlock, self).__init__()
        layers = []
        norm_layer = get_norm_layer(norm_type='layer')  # functools.partial(LayerNorm)
        # norm_layer = None
        nl_layer = nn.ReLU()
        if norm_layer is not None:
            layers += [norm_layer(inplanes)]
        layers += [nl_layer]
        layers += [nn.ReplicationPad2d(1),
                   nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1,
                             padding=0, bias=True)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

def define_SVAE(inc=96, outc=3, outplanes=64, blocks=1, netVAE='SVAE', model_name='', load_ext=True, save_dir='',
                init_type="normal", init_gain=0.02, gpu_ids=[]):
    if netVAE == 'SVAE':
        net = ScreenVAE(inc=inc, outc=outc, outplanes=outplanes, blocks=blocks, save_dir=save_dir,
                        init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netVAE)
    init_net(net, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
    net.load_networks('latest')
    return net


class ScreenVAE(nn.Module):
    def __init__(self, inc=1, outc=4, outplanes=64, downs=5, blocks=2, load_ext=True, save_dir='',
                 init_type="normal", init_gain=0.02, gpu_ids=[]):
        super(ScreenVAE, self).__init__()
        self.inc = inc
        self.outc = outc
        self.save_dir = save_dir
        norm_layer = functools.partial(LayerNormWarpper)
        nl_layer = nn.LeakyReLU

        self.model_names = ['enc', 'dec']
        self.enc = define_C(inc + 1, outc * 2, 0, 24, netC='resnet_6blocks',
                            norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
                            gpu_ids=gpu_ids, upsample='bilinear')
        self.dec = define_G(outc, inc, 0, 48, netG='unet_128_G',
                            norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
                            gpu_ids=gpu_ids, where_add='input', upsample='bilinear', use_noise=True)

        for param in self.parameters():
            param.requires_grad = False

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                state_dict = torch.load(
                    load_path, map_location=lambda storage, loc: storage.cuda())
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                net.load_state_dict(state_dict)
                del state_dict

    def npad(self, im, pad=128):
        h, w = im.shape[-2:]
        hp = h // pad * pad + pad
        wp = w // pad * pad + pad
        return F.pad(im, (0, wp - w, 0, hp - h), mode='replicate')

    def forward(self, x, line=None, img_input=True, output_screen_only=True):
        if img_input:
            if line is None:
                line = torch.ones_like(x)
            else:
                line = torch.sign(line)
                x = torch.clamp(x + (1 - line), -1, 1)
            h, w = x.shape[-2:]
            input = torch.cat([x, line], 1)
            input = self.npad(input)
            inter = self.enc(input)[:, :, :h, :w]
            scr, logvar = torch.split(inter, (self.outc, self.outc), dim=1)
            if output_screen_only:
                return scr
            recons = self.dec(scr)
            return recons, scr, logvar
        else:
            # decoding mode: x holds screentone maps and the caller must supply `line`
            h, w = x.shape[-2:]
            x = self.npad(x)
            recons = self.dec(x)[:, :, :h, :w]
            recons = (recons + 1) * (line + 1) / 2 - 1
            return torch.clamp(recons, -1, 1)
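
# Usage sketch (illustrative, not part of the original file; assumes pretrained
# 'latest_net_enc.pth' / 'latest_net_dec.pth' checkpoints exist under save_dir):
#
#     svae = define_SVAE(inc=1, outc=4, save_dir='./checkpoints/ScreenVAE', gpu_ids=[0])
#     manga = torch.rand(1, 1, 300, 420, device='cuda') * 2 - 1    # page in [-1, 1]
#     scr = svae(manga)                                # (1, 4, 300, 420) screentone maps
#     recons, scr, logvar = svae(manga, output_screen_only=False)
#
# npad pads the input up to the next multiple of 128 plus a margin before
# encoding (e.g. h = 300 -> 300 // 128 * 128 + 128 = 384), and the encoder
# output is cropped back to the original (h, w).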
         
BidirectionalTranslation/options/base_options.py
ADDED
@@ -0,0 +1,142 @@
import argparse
import os
from util import util
import torch
import models
import data

class BaseOptions():
    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        """Initialize options used during both training and test time."""
        # Basic options
        parser.add_argument('--dataroot', required=False, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--batch_size', type=int, default=2, help='input batch size')
        parser.add_argument('--load_size', type=int, default=512, help='scale images to this size')  # Modified default
        parser.add_argument('--crop_size', type=int, default=1024, help='then crop to this size')    # Modified default
        parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels')     # Modified default
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')   # Modified default
        parser.add_argument('--nz', type=int, default=64, help='dimension of the latent vector')     # Modified default
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2, -1 for CPU mode')
        parser.add_argument('--name', type=str, default='color2manga_cycle_ganstft', help='name of the experiment')  # Modified default
        parser.add_argument('--preprocess', type=str, default='none', help='not implemented')         # Modified default
        parser.add_argument('--dataset_mode', type=str, default='aligned', help='aligned,single')
        parser.add_argument('--model', type=str, default='cycle_ganstft', help='chooses which model to use')
        parser.add_argument('--direction', type=str, default='BtoA', help='AtoB or BtoA')            # Modified default
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--local_rank', default=0, type=int, help='process rank, set by distributed launchers')
        parser.add_argument('--checkpoints_dir', type=str, default=self.model_global_path+'/ScreenStyle/color2manga/', help='models are saved here')  # Modified default
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset.')
        parser.add_argument('--no_flip', action='store_false', help='if specified, do not flip the images for data augmentation')  # Modified default

        # Model parameters
        parser.add_argument('--level', type=int, default=0, help='level to train')
        parser.add_argument('--num_Ds', type=int, default=2, help='number of Discriminators')
        parser.add_argument('--netD', type=str, default='basic_256_multi', help='selects model to use for netD')
        parser.add_argument('--netD2', type=str, default='basic_256_multi', help='selects model to use for netD2')
        parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG')
        parser.add_argument('--netC', type=str, default='unet_128', help='selects model to use for netC')
        parser.add_argument('--netE', type=str, default='conv_256', help='selects model to use for netE')
        parser.add_argument('--nef', type=int, default=48, help='# of encoder filters in the first conv layer')  # Modified default
        parser.add_argument('--ngf', type=int, default=48, help='# of gen filters in the last conv layer')       # Modified default
        parser.add_argument('--ndf', type=int, default=32, help='# of discrim filters in the first conv layer')  # Modified default
        parser.add_argument('--norm', type=str, default='layer', help='instance | batch | layer normalization')
        parser.add_argument('--upsample', type=str, default='bilinear', help='basic | bilinear')                  # Modified default
        parser.add_argument('--nl', type=str, default='prelu', help='non-linearity activation: relu | lrelu | elu')
        parser.add_argument('--no_encode', action='store_true', help='if specified, do not use the encoder')
        parser.add_argument('--color2screen', action='store_true', help='also load the RGB (color-to-screen) model')  # Modified default

        # Extra parameters
        parser.add_argument('--where_add', type=str, default='all', help='input|all|middle; where to add z in the network G')
        parser.add_argument('--conditional_D', action='store_true', help='if use conditional GAN for D')
        parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--center_crop', action='store_true', help='apply center cropping at test time')  # Modified default
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size')

        # Special tasks
        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options (only once)."""
        if not self.initialized:
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # Get the basic options
        opt, _ = parser.parse_known_args()

        # Modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # Parse again with new defaults

        # Modify dataset-related parser options
        dataset_name = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # Save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options."""
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # Save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if not os.path.exists(expr_dir):
            try:
                util.mkdirs(expr_dir)
            except OSError:  # directory may be unwritable or created concurrently
                pass
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    def parse(self, model_global_path='.'):  # default lets legacy callers (e.g. test.py) use parse() without a path
        """Parse options, create checkpoints directory suffix, and set up gpu device."""
        self.model_global_path = model_global_path
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # Process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix

        self.print_options(opt)

        # Set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt
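
# A miniature sketch (illustrative names, not this repository's API) of the
# staged parsing pattern gather_options relies on: parse_known_args lets the
# chosen model/dataset register extra flags before the final parse.
if __name__ == '__main__':
    _p = argparse.ArgumentParser()
    _p.add_argument('--model', type=str, default='cycle_ganstft')
    _opt, _ = _p.parse_known_args()        # first pass: discover the model name
    if _opt.model == 'cycle_ganstft':      # stand-in for models.get_option_setter(...)
        _p.add_argument('--lambda_GAN', type=float, default=1.0)
    print(_p.parse_args())                 # final parse sees the injected flag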
         
BidirectionalTranslation/options/test_options.py
ADDED
@@ -0,0 +1,19 @@
from .base_options import BaseOptions

class TestOptions(BaseOptions):
    def initialize(self, parser):
        BaseOptions.initialize(self, parser)

        # Additional test-specific arguments
        parser.add_argument('--results_dir', type=str, default='../results/', help='saves results here.')
        parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
        parser.add_argument('--num_test', type=int, default=30, help='how many test images to run')
        parser.add_argument('--n_samples', type=int, default=1, help='#samples')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio for the results')
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--folder', type=str, default='intra', help='name of the test image subfolder')
        parser.add_argument('--sync', action='store_true', help='use the same latent code for different input images')

        self.isTrain = False
        return parser
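
# Typical test-time usage (sketch; the weights path is a placeholder for
# wherever the ScreenStyle checkpoints were downloaded):
#
#     from options.test_options import TestOptions
#     opt = TestOptions().parse('./models')   # model_global_path
#     assert not opt.isTrain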
         
BidirectionalTranslation/requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch~=1.6.0
torchvision~=0.7.0  # torchvision 0.7.x is the release paired with torch 1.6.x
tensorboardx~=1.9
scipy==1.1
dominate~=2.3.1
scikit-image~=0.16.2
opencv-python~=3.4.2
lpips
         
BidirectionalTranslation/scripts/test_western2manga.sh
ADDED
@@ -0,0 +1,49 @@
set -ex
# models
RESULTS_DIR='./results/test/western2manga'

# dataset
CLASS='color2manga'
MODEL='cycle_ganstft'
DIRECTION='BtoA' # from domain B to domain A
PREPROCESS='none'
LOAD_SIZE=512 # scale images to this size
CROP_SIZE=1024 # then crop to this size
INPUT_NC=1  # number of channels in the input image
OUTPUT_NC=3  # number of channels in the output image
NGF=48
NEF=48
NDF=32
NZ=64

# misc
GPU_ID=0   # gpu id
NUM_TEST=30 # number of input images during test
NUM_SAMPLES=1 # number of samples per input image
NAME=${CLASS}_${MODEL}

# command
CUDA_VISIBLE_DEVICES=${GPU_ID} \
python3 ./test.py \
  --dataroot ./datasets/${CLASS} \
  --results_dir ${RESULTS_DIR} \
  --checkpoints_dir ./checkpoints/${CLASS}/ \
  --name ${NAME} \
  --model ${MODEL} \
  --direction ${DIRECTION} \
  --preprocess ${PREPROCESS} \
  --load_size ${LOAD_SIZE} \
  --crop_size ${CROP_SIZE} \
  --input_nc ${INPUT_NC} \
  --output_nc ${OUTPUT_NC} \
  --nz ${NZ} \
  --netE conv_256 \
  --num_test ${NUM_TEST} \
  --n_samples ${NUM_SAMPLES} \
  --upsample bilinear \
  --ngf ${NGF} \
  --nef ${NEF} \
  --ndf ${NDF} \
  --center_crop \
  --color2screen \
  --no_flip
         
BidirectionalTranslation/test.py
ADDED
@@ -0,0 +1,71 @@
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from itertools import islice
from util import html
import cv2

seed = 10
import torch
import numpy as np
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# options
opt = TestOptions().parse()
opt.num_threads = 1   # test code only supports num_threads=1
opt.batch_size = 1    # test code only supports batch_size=1
opt.serial_batches = True  # no shuffle

model = create_model(opt)
model.setup(opt)
model.eval()
print('Loading model %s' % opt.model)

testdata = ['manga_paper']
# fake_sty = model.get_z_random(1, 64, truncation=True)

opt.dataset_mode = 'singleSr'
for folder in testdata:
    opt.folder = folder
    # create dataset
    dataset = create_dataset(opt)
    web_dir = os.path.join(opt.results_dir, opt.folder + '_Sr2Co')
    webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class = %s' % (opt.name, opt.phase, opt.name))
    # fake_sty = model.get_z_random(1, 64, truncation=True)
    for i, data in enumerate(islice(dataset, opt.num_test)):
        h = data['h']
        w = data['w']
        model.set_input(data)
        fake_sty = model.get_z_random(1, 64, truncation=True, tvalue=1.25)
        fake_B, SCR, line = model.forward(AtoB=False, sty=fake_sty)
        images = [fake_B[:, :, :h, :w]]
        names = ['color']

        img_path = 'input_%3.3d' % i
        save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
    webpage.save()

testdata = ['western_paper']

opt.dataset_mode = 'singleCo'
for folder in testdata:
    opt.folder = folder
    # create dataset
    dataset = create_dataset(opt)
    web_dir = os.path.join(opt.results_dir, opt.folder + '_Sr2Co')
    webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class = %s' % (opt.name, opt.phase, opt.name))
    for i, data in enumerate(islice(dataset, opt.num_test)):
        h = data['h']
        w = data['w']
        model.set_input(data)
        fake_B, fake_B2, SCR = model.forward(AtoB=True)
        images = [fake_B2[:, :, :h, :w]]
        names = ['manga']

        img_path = 'input_%3.3d' % i
        save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
    webpage.save()
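
# get_z_random(..., truncation=True, tvalue=1.25) above implies truncated
# sampling of the style code. The model's implementation is not shown in this
# diff; a common realization (sketch only) redraws out-of-range normal samples:
def truncated_normal(batch, nz, tvalue=1.25):
    z = torch.randn(batch, nz)
    while (z.abs() > tvalue).any():
        mask = z.abs() > tvalue
        z[mask] = torch.randn(int(mask.sum()))  # resample only the outliers
    return z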
         
BidirectionalTranslation/util/html.py
ADDED
@@ -0,0 +1,86 @@
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os


class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

    It consists of functions such as <add_header> (add a text header to the HTML file),
    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
    It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML classes

        Parameters:
            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
            title (str)   -- the webpage name
            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """add images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
        """
            +
                    self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
         
     | 
| 57 | 
         
            +
                    self.doc.add(self.t)
         
     | 
| 58 | 
         
            +
                    with self.t:
         
     | 
| 59 | 
         
            +
                        with tr():
         
     | 
| 60 | 
         
            +
                            for im, txt, link in zip(ims, txts, links):
         
     | 
| 61 | 
         
            +
                                with td(style="word-wrap: break-word;", halign="center", valign="top"):
         
     | 
| 62 | 
         
            +
                                    with p():
         
     | 
| 63 | 
         
            +
                                        with a(href=os.path.join('images', link)):
         
     | 
| 64 | 
         
            +
                                            img(style="width:%dpx" % width, src=os.path.join('images', im))
         
     | 
| 65 | 
         
            +
                                        br()
         
     | 
| 66 | 
         
            +
                                        p(txt)
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                def save(self):
         
     | 
| 69 | 
         
            +
                    """save the current content to the HMTL file"""
         
     | 
| 70 | 
         
            +
                    html_file = '%s/index.html' % self.web_dir
         
     | 
| 71 | 
         
            +
                    f = open(html_file, 'wt')
         
     | 
| 72 | 
         
            +
                    f.write(self.doc.render())
         
     | 
| 73 | 
         
            +
                    f.close()
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
            if __name__ == '__main__':  # we show an example usage here.
         
     | 
| 77 | 
         
            +
                html = HTML('web/', 'test_html')
         
     | 
| 78 | 
         
            +
                html.add_header('hello world')
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
                ims, txts, links = [], [], []
         
     | 
| 81 | 
         
            +
                for n in range(4):
         
     | 
| 82 | 
         
            +
                    ims.append('image_%d.png' % n)
         
     | 
| 83 | 
         
            +
                    txts.append('text_%d' % n)
         
     | 
| 84 | 
         
            +
                    links.append('image_%d.png' % n)
         
     | 
| 85 | 
         
            +
                html.add_images(ims, txts, links)
         
     | 
| 86 | 
         
            +
                html.save()
         
     | 
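A quick way to sanity-check the page builder above, as a minimal sketch rather than part of the commit (it assumes 'dominate' is installed and that BidirectionalTranslation/util is importable as the 'util' package):

# Sketch only (assumes 'dominate' is installed and util/ is on sys.path).
from util.html import HTML  # the module above, not the stdlib 'html'

page = HTML('web/', 'demo')            # __init__ creates web/ and web/images/
page.add_header('results')
page.add_images(['a.png'], ['input'], ['a.png'], width=256)
print(page.doc.render())               # inspect the <table>/<img> markup before save() writes web/index.html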
    	
BidirectionalTranslation/util/util.py ADDED
@@ -0,0 +1,136 @@
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+import pickle
+
+
+def tensor2im(input_image, imtype=np.uint8):
+    """Convert a Tensor array into a numpy image array.
+    Parameters:
+        input_image (tensor) --  the input image tensor array
+        imtype (type)        --  the desired type of the converted numpy array
+    """
+    if not isinstance(input_image, np.ndarray):
+        if isinstance(input_image, torch.Tensor):  # get the data from a variable
+            image_tensor = input_image.data
+        else:
+            return input_image
+        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+        if image_numpy.shape[0] == 1:  # grayscale to RGB
+            image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+    else:  # if it is a numpy array, do nothing
+        image_numpy = input_image
+    return image_numpy.astype(imtype)
+
+
+def tensor2vec(vector_tensor):
+    numpy_vec = vector_tensor.data.cpu().numpy()
+    if numpy_vec.ndim == 4:
+        return numpy_vec[:, :, 0, 0]
+    else:
+        return numpy_vec
+
+
+def pickle_load(file_name):
+    data = None
+    with open(file_name, 'rb') as f:
+        data = pickle.load(f)
+    return data
+
+
+def pickle_save(file_name, data):
+    with open(file_name, 'wb') as f:
+        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def diagnose_network(net, name='network'):
+    """Calculate and print the mean absolute value of the gradients
+    Parameters:
+        net (torch network) -- Torch network
+        name (str) -- the name of the network
+    """
+    mean = 0.0
+    count = 0
+    for param in net.parameters():
+        if param.grad is not None:
+            mean += torch.mean(torch.abs(param.grad.data))
+            count += 1
+    if count > 0:
+        mean = mean / count
+    print(name)
+    print(mean)
+
+
+def interp_z(z0, z1, num_frames, interp_mode='linear'):
+    zs = []
+    if interp_mode == 'linear':
+        for n in range(num_frames):
+            ratio = n / float(num_frames - 1)
+            z_t = (1 - ratio) * z0 + ratio * z1
+            zs.append(z_t[np.newaxis, :])
+        zs = np.concatenate(zs, axis=0).astype(np.float32)
+
+    if interp_mode == 'slerp':
+        z0_n = z0 / (np.linalg.norm(z0) + 1e-10)
+        z1_n = z1 / (np.linalg.norm(z1) + 1e-10)
+        omega = np.arccos(np.dot(z0_n, z1_n))
+        sin_omega = np.sin(omega)
+        if sin_omega < 1e-10 and sin_omega > -1e-10:  # (nearly) parallel codes: fall back to linear interpolation
+            return interp_z(z0, z1, num_frames, interp_mode='linear')  # return directly so the result is not concatenated twice
+        else:
+            for n in range(num_frames):
+                ratio = n / float(num_frames - 1)
+                z_t = np.sin((1 - ratio) * omega) / sin_omega * z0 + np.sin(ratio * omega) / sin_omega * z1
+                zs.append(z_t[np.newaxis, :])
+        zs = np.concatenate(zs, axis=0).astype(np.float32)
+
+    return zs
+
+
+def save_image(image_numpy, image_path):
+    """Save a numpy image to the disk
+    Parameters:
+        image_numpy (numpy array) -- input numpy array
+        image_path (str)          -- the path of the image
+    """
+    image_pil = Image.fromarray(image_numpy)
+    image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+    """Print the mean, min, max, median, std, and size of a numpy array
+    Parameters:
+        val (bool) -- whether to print the values of the numpy array
+        shp (bool) -- whether to print the shape of the numpy array
+    """
+    x = x.astype(np.float64)
+    if shp:
+        print('shape,', x.shape)
+    if val:
+        x = x.flatten()
+        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+    """Create empty directories if they don't exist
+    Parameters:
+        paths (str list) -- a list of directory paths
+    """
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+def mkdir(path):
+    """Create a single empty directory if it doesn't exist
+    Parameters:
+        path (str) -- a single directory path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path, exist_ok=True)
     | 
    	
        BidirectionalTranslation/util/visualizer.py
    ADDED
    
    | 
         @@ -0,0 +1,221 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util
+from . import html
+from subprocess import Popen, PIPE
+import cv2
+
+
+# if sys.version_info[0] == 2:
+#     VisdomExceptionBase = Exception
+# else:
+#     VisdomExceptionBase = ConnectionError
+
+
+def save_images(webpage, images, names, image_path, aspect_ratio=1.0, width=256):
+    """Save images to the disk.
+    Parameters:
+        webpage (the HTML class)  -- the HTML webpage class that stores these images (see html.py for more details)
+        images (numpy array list) -- a list of numpy arrays that store images
+        names (str list)          -- a str list that stores the names of the images above
+        image_path (str)          -- the string used to create image paths
+        aspect_ratio (float)      -- the aspect ratio of saved images
+        width (int)               -- the images will be resized to width x width
+    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+    """
+    image_dir = webpage.get_image_dir()
+    name = ntpath.basename(image_path)
+
+    webpage.add_header(name)
+    ims, txts, links = [], [], []
+
+    for label, im_data in zip(names, images):
+        im = util.tensor2im(im_data)
+        image_name = '%s_%s.jpg' % (name, label)
+        save_path = os.path.join(image_dir, image_name)
+        h, w, _ = im.shape
+        if aspect_ratio > 1.0:
+            im = cv2.resize(im, (h, int(w * aspect_ratio)), interpolation=cv2.INTER_CUBIC)
+        if aspect_ratio < 1.0:
+            im = cv2.resize(im, (int(h / aspect_ratio), w), interpolation=cv2.INTER_CUBIC)
+        util.save_image(im, save_path)
+
+        ims.append(image_name)
+        txts.append(label)
+        links.append(image_name)
+    webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+    """This class includes several functions that can display/save images and print/save logging information.
+    It uses the Python library 'visdom' for display, and the Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+    """
+
+    def __init__(self, opt):
+        """Initialize the Visualizer class
+        Parameters:
+            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        Step 1: Cache the training/test options
+        Step 2: connect to a visdom server
+        Step 3: create an HTML object for saving HTML files
+        Step 4: create a logging file to store training losses
+        """
+        self.opt = opt  # cache the option
+        self.display_id = opt.display_id
+        self.use_html = opt.isTrain and not opt.no_html
+        self.win_size = opt.display_winsize
+        self.name = opt.name
+        self.port = opt.display_port
+        self.saved = False
+        # if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
+        #     import visdom
+        #     self.ncols = opt.display_ncols
+        #     self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
+        #     if not self.vis.check_connection():
+        #         self.create_visdom_connections()
+        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
+            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+            self.img_dir = os.path.join(self.web_dir, 'images')
+            print('create web directory %s...' % self.web_dir)
+            util.mkdirs([self.web_dir, self.img_dir])
+        # create a logging file to store training losses
+        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+        with open(self.log_name, "a") as log_file:
+            now = time.strftime("%c")
+            log_file.write('================ Training Loss (%s) ================\n' % now)
+
+    def reset(self):
+        """Reset the self.saved status"""
+        self.saved = False
+
+    def create_visdom_connections(self):
+        """If the program could not connect to a Visdom server, this function will start a new server at port <self.port>"""
+        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
+        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
+        print('Command: %s' % cmd)
+        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
+
+    def display_current_results(self, visuals, epoch, save_result):
+        """Display current results on visdom; save current results to an HTML file.
+        Parameters:
+            visuals (OrderedDict) - - dictionary of images to display or save
+            epoch (int) - - the current epoch
+            save_result (bool) - - whether to save the current results to an HTML file
+        """
+        # if self.display_id > 0:  # show images in the browser using visdom
+        #     ncols = self.ncols
+        #     if ncols > 0:        # show all the images in one visdom panel
+        #         ncols = min(ncols, len(visuals))
+        #         h, w = next(iter(visuals.values())).shape[:2]
+        #         table_css = """<style>
+        #                 table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
+        #                 table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
+        #                 </style>""" % (w, h)  # create a table css
+        #         # create a table of images.
+        #         title = self.name
+        #         label_html = ''
+        #         label_html_row = ''
+        #         images = []
+        #         idx = 0
+        #         for label, image in visuals.items():
+        #             image_numpy = util.tensor2im(image)
+        #             label_html_row += '<td>%s</td>' % label
+        #             images.append(image_numpy.transpose([2, 0, 1]))
+        #             idx += 1
+        #             if idx % ncols == 0:
+        #                 label_html += '<tr>%s</tr>' % label_html_row
+        #                 label_html_row = ''
+        #         white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
+        #         while idx % ncols != 0:
+        #             images.append(white_image)
+        #             label_html_row += '<td></td>'
+        #             idx += 1
+        #         if label_html_row != '':
+        #             label_html += '<tr>%s</tr>' % label_html_row
+        #         try:
+        #             self.vis.images(images, nrow=ncols, win=self.display_id + 1,
+        #                             padding=2, opts=dict(title=title + ' images'))
+        #             label_html = '<table>%s</table>' % label_html
+        #             self.vis.text(table_css + label_html, win=self.display_id + 2,
+        #                           opts=dict(title=title + ' labels'))
+        #         except VisdomExceptionBase:
+        #             self.create_visdom_connections()
+
+        #     else:     # show each image in a separate visdom panel;
+        #         idx = 1
+        #         try:
+        #             for label, image in visuals.items():
+        #                 image_numpy = util.tensor2im(image)
+        #                 self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+        #                                win=self.display_id + idx)
+        #                 idx += 1
+        #         except VisdomExceptionBase:
+        #             self.create_visdom_connections()
+
+        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
+            self.saved = True
+            # save images to the disk
+            for label, image in visuals.items():
+                image_numpy = util.tensor2im(image)
+                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+                util.save_image(image_numpy, img_path)
+
+            # update website
+            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+            for n in range(epoch, 0, -1):
+                webpage.add_header('epoch [%d]' % n)
+                ims, txts, links = [], [], []
+
+                for label, image in visuals.items():  # fixed: iterate over the image tensors, not a stale loop variable
+                    image_numpy = util.tensor2im(image)
+                    img_path = 'epoch%.3d_%s.png' % (n, label)
+                    ims.append(img_path)
+                    txts.append(label)
+                    links.append(img_path)
+                webpage.add_images(ims, txts, links, width=self.win_size)
+            webpage.save()
+
+    def plot_current_losses(self, epoch, counter_ratio, losses):
+        """display the current losses on visdom display: dictionary of error labels and values
+        Parameters:
+            epoch (int)           -- current epoch
+            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
+            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
+        """
+        if not hasattr(self, 'plot_data'):
+            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
+        self.plot_data['X'].append(epoch + counter_ratio)
+        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+        # try:
+        #     self.vis.line(
+        #         X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
+        #         Y=np.array(self.plot_data['Y']),
+        #         opts={
+        #             'title': self.name + ' loss over time',
+        #             'legend': self.plot_data['legend'],
+        #             'xlabel': 'epoch',
+        #             'ylabel': 'loss'},
+        #         win=self.display_id)
+        # except VisdomExceptionBase:
+        #     self.create_visdom_connections()
+
+    # losses: same format as |losses| of plot_current_losses
+    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+        """print current losses on console; also save the losses to the disk
+        Parameters:
+            epoch (int) -- current epoch
+            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+            t_comp (float) -- computational time per data point (normalized by batch_size)
+            t_data (float) -- data loading time per data point (normalized by batch_size)
+        """
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+        for k, v in losses.items():
+            message += '%s: %.3f ' % (k, v)
+
+        print(message)  # print the message
+        with open(self.log_name, "a") as log_file:
+            log_file.write('%s\n' % message)  # save the message
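Taken together, html.py, util.py, and visualizer.py form the result-dumping pipeline that test.py drives. A minimal end-to-end sketch (dummy tensors and paths, not part of the commit; assumes the util package is importable):

# Sketch: write two dummy results to a browsable HTML page, mirroring test.py's flow.
import torch
from util import html
from util.visualizer import save_images

webpage = html.HTML('results/demo', 'demo run')
real = torch.rand(1, 3, 256, 256) * 2 - 1   # tensor2im expects values in [-1, 1]
fake = torch.rand(1, 3, 256, 256) * 2 - 1
save_images(webpage, [real, fake], ['real', 'fake'],
            'results/demo/sample_000.png',   # only the basename is used for naming
            aspect_ratio=1.0, width=256)
webpage.save()                               # writes results/demo/index.html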
    	
app.py ADDED
@@ -0,0 +1,507 @@
import contextlib
import gc
import json
import logging
import math
import os
import random
import shutil
import sys
import time
import itertools
from pathlib import Path

import cv2
import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm

import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

from datasets import load_dataset
from huggingface_hub import create_repo, snapshot_download, upload_folder
from packaging import version
from safetensors.torch import load_model
from peft import LoraConfig
import gradio as gr
import pandas as pd

import transformers
from transformers import (
    AutoTokenizer,
    PretrainedConfig,
    CLIPVisionModelWithProjection,
    CLIPImageProcessor,
    CLIPProcessor,
)

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    ColorGuiderPixArtModel,
    ColorGuiderSDModel,
    UNet2DConditionModel,
    PixArtTransformer2DModel,
    ColorFlowPixArtAlphaPipeline,
    ColorFlowSDPipeline,
    UniPCMultistepScheduler,
)
from util_colorflow.utils import *

# The ScreenVAE code lives in a separate repo checked into this Space.
sys.path.append('./BidirectionalTranslation')
from options.test_options import TestOptions
from models import create_model
from util import util

# Download all ColorFlow checkpoints once and reuse the local snapshot path.
model_global_path = snapshot_download(repo_id="JunhaoZhuang/ColorFlow", cache_dir='./colorflow/')
print(model_global_path)
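
# The checkpoints referenced below live in subfolders of the snapshot
# (LE/, GraySD/, sketch/, image_encoder/). A minimal sketch, using only the
# standard library, to inspect what snapshot_download actually fetched
# (illustrative helper; the app never calls it):
def _list_snapshot(root):
    # Walk the cached snapshot and print each file with its size in MB.
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            full = os.path.join(dirpath, name)
            print(f"{os.path.relpath(full, root)}  ({os.path.getsize(full) / 1e6:.1f} MB)")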

# Normalization to [-1, 1], matching the VAE's expected input range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
weight_dtype = torch.float16

# Line-extraction model (res_skip comes from util_colorflow.utils).
line_model_path = model_global_path + '/LE/erika.pth'
line_model = res_skip()
line_model.load_state_dict(torch.load(line_model_path))
line_model.eval()
line_model.cuda()

# ScreenVAE model used to convert gray images into the screentone style.
opt = TestOptions().parse(model_global_path)
ScreenModel = create_model(opt, model_global_path)
ScreenModel.setup(opt)
ScreenModel.eval()

# CLIP image encoder used for patch-level reference retrieval.
image_processor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to('cuda')
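
# `transform` maps a PIL image to a [-1, 1] tensor; the post-processing at the
# end of colorize_image inverts it with `* 0.5 + 0.5`. A minimal sketch of that
# round trip (illustrative helper, not used by the app):
def _denormalize_to_pil(tensor):
    # tensor: (3, H, W) in [-1, 1] -> PIL RGB image
    array = (tensor * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((array * 255).astype(np.uint8))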

# Each row: [input image, reference images, input style, resolution, seed, inference steps].
examples = [
    [
        "./assets/example_5/input.png",
        ["./assets/example_5/ref1.png", "./assets/example_5/ref2.png", "./assets/example_5/ref3.png"],
        "GrayImage(ScreenStyle)", "800x512", 0, 10,
    ],
    [
        "./assets/example_4/input.jpg",
        ["./assets/example_4/ref1.jpg", "./assets/example_4/ref2.jpg", "./assets/example_4/ref3.jpg"],
        "GrayImage(ScreenStyle)", "640x640", 0, 10,
    ],
    [
        "./assets/example_3/input.png",
        ["./assets/example_3/ref1.png", "./assets/example_3/ref2.png", "./assets/example_3/ref3.png"],
        "GrayImage(ScreenStyle)", "800x512", 0, 10,
    ],
    [
        "./assets/example_2/input.png",
        ["./assets/example_2/ref1.png", "./assets/example_2/ref2.png", "./assets/example_2/ref3.png"],
        "GrayImage(ScreenStyle)", "800x512", 0, 10,
    ],
    [
        "./assets/example_1/input.jpg",
        ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
        "Sketch", "640x640", 0, 10,
    ],
    [
        "./assets/example_0/input.jpg",
        ["./assets/example_0/ref1.jpg"],
        "Sketch", "640x640", 0, 10,
    ],
]
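
# The column order above must match the gr.Examples `inputs` list defined at
# the bottom of this file. A small sketch (not called at startup) to verify
# that every example asset actually exists on disk:
def _check_examples(rows):
    for input_path, ref_paths, *_ in rows:
        for path in [input_path] + list(ref_paths):
            if not os.path.exists(path):
                print(f"missing example asset: {path}")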

# Populated by load_ckpt below.
pipeline = None
MultiResNetModel = None


def load_ckpt(input_style):
    global pipeline
    global MultiResNetModel
    if input_style == "Sketch":
        # PixArt-alpha backbone with a rank-128 LoRA plus a ColorGuider branch.
        ckpt_path = model_global_path + '/sketch/'
        rank = 128
        pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
        transformer = PixArtTransformer2DModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
        )
        pixart_config = get_pixart_config()

        ColorGuider = ColorGuiderPixArtModel.from_pretrained(ckpt_path)

        transformer_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2", "proj", "linear", "linear_1", "linear_2"],
        )
        transformer.add_adapter(transformer_lora_config)
        ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
        transformer.load_state_dict(ckpt_key_t, strict=False)

        transformer.to('cuda', dtype=weight_dtype)
        ColorGuider.to('cuda', dtype=weight_dtype)

        pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
            pretrained_model_name_or_path,
            transformer=transformer,
            colorguider=ColorGuider,
            safety_checker=None,
            revision=None,
            variant=None,
            torch_dtype=weight_dtype,
        )
        pipeline = pipeline.to("cuda")
        block_out_channels = [128, 128, 256, 512, 512]

        MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
        MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
        MultiResNetModel.to('cuda', dtype=weight_dtype)

    elif input_style == "GrayImage(ScreenStyle)":
        # Stable Diffusion 1.5 backbone with a rank-64 LoRA plus a ColorGuider branch.
        ckpt_path = model_global_path + '/GraySD/'
        rank = 64
        pretrained_model_name_or_path = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
        unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
        )
        ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
        ColorGuider.to('cuda', dtype=weight_dtype)
        unet.to('cuda', dtype=weight_dtype)

        pipeline = ColorFlowSDPipeline.from_pretrained(
            pretrained_model_name_or_path,
            unet=unet,
            colorguider=ColorGuider,
            safety_checker=None,
            revision=None,
            variant=None,
            torch_dtype=weight_dtype,
        )
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
        unet_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],
        )
        pipeline.unet.add_adapter(unet_lora_config)
        pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
        pipeline = pipeline.to("cuda")
        block_out_channels = [128, 128, 256, 512, 512]

        MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
        MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
        MultiResNetModel.to('cuda', dtype=weight_dtype)
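
# Only the LoRA adapters and the ColorGuider are ColorFlow-specific; the
# backbone weights come from the pretrained checkpoints. A minimal sketch
# (assumes peft's "lora" parameter-naming convention) to see how small the
# adapter is relative to the backbone:
def _count_lora_params(model):
    lora = sum(p.numel() for n, p in model.named_parameters() if "lora" in n)
    total = sum(p.numel() for p in model.parameters())
    print(f"LoRA params: {lora / 1e6:.1f}M of {total / 1e6:.1f}M total")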

# Warm up both checkpoints once at startup; the second call leaves the
# GrayImage(ScreenStyle) pipeline active, matching the UI default.
cur_input_style = "Sketch"
load_ckpt(cur_input_style)
cur_input_style = "GrayImage(ScreenStyle)"
load_ckpt(cur_input_style)


def fix_random_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def process_multi_images(files):
    return [Image.open(file.name) for file in files]
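
# Depending on the Gradio version, gr.Files may deliver tempfile wrappers
# (objects with a .name attribute) or plain path strings. A defensive variant
# (sketch only, not wired into the UI) that accepts both:
def _open_any(files):
    return [Image.open(f if isinstance(f, str) else f.name) for f in files]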

def extract_lines(image):
    src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

    # Pad the input up to a multiple of 16 so the network's strides divide evenly.
    rows = int(np.ceil(src.shape[0] / 16)) * 16
    cols = int(np.ceil(src.shape[1] / 16)) * 16

    patch = np.ones((1, 1, rows, cols), dtype="float32")
    patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src

    tensor = torch.from_numpy(patch).cuda()

    with torch.no_grad():
        y = line_model(tensor)

    # Clip the network output to the valid 8-bit range.
    yc = y.cpu().numpy()[0, 0, :, :]
    yc[yc > 255] = 255
    yc[yc < 0] = 0

    # Crop the padding away and convert back to a PIL image.
    outimg = yc[0:src.shape[0], 0:src.shape[1]]
    outimg = outimg.astype(np.uint8)
    outimg = Image.fromarray(outimg)
    torch.cuda.empty_cache()
    return outimg
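
# The pad-to-multiple step above is a common trick for fully convolutional
# networks whose encoders downsample the input several times. A reusable
# sketch of the same idea (illustrative, not used by extract_lines):
def _pad_to_multiple(array, multiple=16):
    rows = int(np.ceil(array.shape[0] / multiple)) * multiple
    cols = int(np.ceil(array.shape[1] / multiple)) * multiple
    padded = np.zeros((rows, cols), dtype=array.dtype)
    padded[:array.shape[0], :array.shape[1]] = array
    return padded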

def to_screen_image(input_image):
    global opt
    global ScreenModel
    input_image = input_image.convert('RGB')
    input_image = get_ScreenVAE_input(input_image, opt)
    h = input_image['h']
    w = input_image['w']
    ScreenModel.set_input(input_image)
    fake_B, fake_B2, SCR = ScreenModel.forward(AtoB=True)
    images = fake_B2[:, :, :h, :w]
    im = util.tensor2im(images)
    image_pil = Image.fromarray(im)
    torch.cuda.empty_cache()
    return image_pil
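
# to_screen_image runs the ScreenVAE translation in the A-to-B direction
# (presumably western/gray input to the screentone domain, following the
# repo's western2manga scripts) and crops the padded output back to the
# original h x w before converting it to a PIL image.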

def extract_line_image(query_image_, input_style, resolution):
    if resolution == "640x640":
        tar_width = 640
        tar_height = 640
    elif resolution == "512x800":
        tar_width = 512
        tar_height = 800
    elif resolution == "800x512":
        tar_width = 800
        tar_height = 512
    else:
        # Abort early: tar_width/tar_height would be undefined below.
        gr.Info("Unsupported resolution")
        raise ValueError("Unsupported resolution")

    query_image = process_image(query_image_, int(tar_width * 1.5), int(tar_height * 1.5))
    if input_style == "GrayImage(ScreenStyle)":
        extracted_line = to_screen_image(query_image)
        extracted_line = Image.blend(extracted_line.convert('L').convert('RGB'), query_image.convert('L').convert('RGB'), 0.5)
        input_context = extracted_line
    elif input_style == "Sketch":
        query_image = query_image.convert('L').convert('RGB')
        extracted_line = extract_lines(query_image)
        extracted_line = extracted_line.convert('L').convert('RGB')
        input_context = extracted_line
    torch.cuda.empty_cache()
    return input_context, extracted_line, input_context
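
# The resolution string is parsed the same way here and in colorize_image. A
# sketch of a shared helper that would replace the duplicated if/elif ladder
# (hypothetical refactor, not used by the handlers above):
def _parse_resolution(resolution):
    try:
        width, height = (int(v) for v in resolution.split("x"))
    except ValueError:
        raise ValueError(f"Unsupported resolution: {resolution}")
    return width, height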

def colorize_image(VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps):
    if VAE_input is None or input_context is None:
        gr.Info("Please preprocess the image first")
        raise ValueError("Please preprocess the image first")
    global cur_input_style
    global pipeline
    global MultiResNetModel
    if input_style != cur_input_style:
        gr.Info(f"Loading {input_style} model...")
        load_ckpt(input_style)
        cur_input_style = input_style
        gr.Info(f"{input_style} model loaded")
    reference_images = process_multi_images(reference_images)
    fix_random_seeds(seed)
    if resolution == "640x640":
        tar_width = 640
        tar_height = 640
    elif resolution == "512x800":
        tar_width = 512
        tar_height = 800
    elif resolution == "800x512":
        tar_width = 800
        tar_height = 512
    else:
        gr.Info("Unsupported resolution")
        raise ValueError("Unsupported resolution")
    validation_mask = Image.open('./assets/mask.png').convert('RGB').resize((tar_width * 2, tar_height * 2))
    gr.Info("Image retrieval in progress...")
    query_image_bw = process_image(input_context, int(tar_width), int(tar_height))
    query_image = query_image_bw.convert('RGB')
    query_image_vae = process_image(VAE_input, int(tar_width * 1.5), int(tar_height * 1.5))
    reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
    query_patches_pil = process_image_Q_varres(query_image, tar_width, tar_height)
    reference_patches_pil = []
    for reference_image in reference_images:
        reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
    combined_image = None
    with torch.no_grad():
        # Embed query and reference patches with CLIP, then rank references
        # per query patch by cosine similarity.
        clip_img = image_processor(images=query_patches_pil, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
        query_embeddings = image_encoder(clip_img).image_embeds
        reference_patches_pil_gray = [rimg.convert('RGB') for rimg in reference_patches_pil]
        clip_img = image_processor(images=reference_patches_pil_gray, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
        reference_embeddings = image_encoder(clip_img).image_embeds
        cosine_similarities = F.cosine_similarity(query_embeddings.unsqueeze(1), reference_embeddings.unsqueeze(0), dim=-1)
        sorted_indices = torch.argsort(cosine_similarities, descending=True, dim=1).tolist()
        top_k = 3
        top_k_indices = [cur_sortlist[:top_k] for cur_sortlist in sorted_indices]
        # Build the 2x2 conditioning canvas: the query sits in the center and
        # each quadrant's top-k references fill the surrounding border cells.
        combined_image = Image.new('RGB', (tar_width * 2, tar_height * 2), 'white')
        combined_image.paste(query_image_bw.resize((tar_width, tar_height)), (tar_width // 2, tar_height // 2))
        idx_table = {0: [(1, 0), (0, 1), (0, 0)], 1: [(1, 3), (0, 2), (0, 3)], 2: [(2, 0), (3, 1), (3, 0)], 3: [(2, 3), (3, 2), (3, 3)]}
        for i in range(2):
            for j in range(2):
                idx_list = idx_table[i * 2 + j]
                for k in range(top_k):
                    ref_index = top_k_indices[i * 2 + j][k]
                    idx_y = idx_list[k][0]
                    idx_x = idx_list[k][1]
                    combined_image.paste(reference_patches_pil[ref_index].resize((tar_width // 2 - 2, tar_height // 2 - 2)), (tar_width // 2 * idx_x + 1, tar_height // 2 * idx_y + 1))
    gr.Info("Model inference in progress...")
    generator = torch.Generator(device='cuda').manual_seed(seed)
    image = pipeline(
        "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
    ).images[0]
    gr.Info("Post-processing image...")
    with torch.no_grad():
        # Crop the colorized query out of the center of the 2x2 canvas.
        width, height = image.size
        new_width = width // 2
        new_height = height // 2
        left = (width - new_width) // 2
        top = (height - new_height) // 2
        right = left + new_width
        bottom = top + new_height
        center_crop = image.crop((left, top, right, bottom))
        up_img = center_crop.resize(query_image_vae.size)
        test_low_color = transform(up_img).unsqueeze(0).to('cuda', dtype=weight_dtype)
        query_image_vae = transform(query_image_vae).unsqueeze(0).to('cuda', dtype=weight_dtype)

        # Fuse encoder features of the colorized crop and the gray original,
        # then decode with the learned skip connections for a high-res result.
        h_color, hidden_list_color = pipeline.vae._encode(test_low_color, return_dict=False, hidden_flag=True)
        h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict=False, hidden_flag=True)

        hidden_list_double = [torch.cat((hidden_list_color[hidden_idx], hidden_list_bw[hidden_idx]), dim=1) for hidden_idx in range(len(hidden_list_color))]

        hidden_list = MultiResNetModel(hidden_list_double)
        output = pipeline.vae._decode(h_color.sample(), return_dict=False, hidden_list=hidden_list)[0]

        output[output > 1] = 1
        output[output < -1] = -1
        high_res_image = Image.fromarray(((output[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)).convert("RGB")
    gr.Info("Colorization complete!")
    torch.cuda.empty_cache()
    return high_res_image, up_img, image, query_image_bw
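
# The retrieval step above boils down to: embed query patches and reference
# patches with CLIP, then take the top-k references per query patch by cosine
# similarity. A minimal self-contained sketch of that core, reusing the
# globals defined earlier in this file:
def _top_k_by_clip(query_pils, ref_pils, k=3):
    with torch.no_grad():
        q = image_processor(images=query_pils, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
        r = image_processor(images=ref_pils, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
        q_emb = image_encoder(q).image_embeds
        r_emb = image_encoder(r).image_embeds
        sims = F.cosine_similarity(q_emb.unsqueeze(1), r_emb.unsqueeze(0), dim=-1)
        return torch.argsort(sims, descending=True, dim=1)[:, :k].tolist()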

with gr.Blocks() as demo:
    gr.HTML(
    """
<div style="text-align: center;">
    <h1 style="text-align: center; font-size: 3em;">🎨 ColorFlow:</h1>
    <h3 style="text-align: center; font-size: 1.8em;">Retrieval-Augmented Image Sequence Colorization</h3>
    <p style="text-align: center; font-weight: bold;">
        <a href="https://zhuang2002.github.io/ColorFlow/">Project Page</a> | 
        <a href="https://arxiv.org/abs/">ArXiv Preprint</a> | 
        <a href="https://github.com/TencentARC/ColorFlow">GitHub Repository</a>
    </p>
    <p style="text-align: center; font-weight: bold;">
        NOTE: Each time you switch the input style, the corresponding model is reloaded, which may take some time. Please be patient.
    </p>
    <p style="text-align: left; font-size: 1.1em;">
        Welcome to the demo of <strong>ColorFlow</strong>. Follow the steps below to explore the capabilities of our model:
    </p>
</div>
<div style="text-align: left; margin: 0 auto;">
    <ol style="font-size: 1.1em;">
        <li>Choose an input style: GrayImage(ScreenStyle) or Sketch.</li>
        <li>Upload your image: use the 'Upload' button to select the image you want to colorize.</li>
        <li>Preprocess the image: click the 'Preprocess' button to decolorize the image.</li>
        <li>Upload reference images: upload multiple reference images to guide the colorization.</li>
        <li>Set sampling parameters (optional): adjust the settings and click the <b>Colorize</b> button.</li>
    </ol>
    <p>
        ⏱️ <b>ZeroGPU Time Limit</b>: Hugging Face ZeroGPU limits inference to 180 seconds, and you may need to log in with a free account to use this demo. A large number of sampling steps may lead to a timeout (GPU Abort); in that case, please consider logging in with a Pro account or running the demo on your local machine.
    </p>
</div>
<div style="text-align: center;">
    <p style="text-align: center; font-weight: bold;">
        注意:每次切换输入样式时,相应的模型将被重新加载,可能需要一些时间。请耐心等待。
    </p>
    <p style="text-align: left; font-size: 1.1em;">
        欢迎使用 <strong>ColorFlow</strong> 演示。请按照以下步骤探索我们模型的能力:
    </p>
</div>
<div style="text-align: left; margin: 0 auto;">
    <ol style="font-size: 1.1em;">
        <li>选择输入样式:灰度图(ScreenStyle)、线稿。</li>
        <li>上传您的图像:使用“上传”按钮选择要上色的图像。</li>
        <li>预处理图像:点击“预处理”按钮以去色图像。</li>
        <li>上传参考图像:上传多张参考图像以指导上色。</li>
        <li>设置采样参数(可选):调整设置并点击 <b>上色</b> 按钮。</li>
    </ol>
    <p>
        ⏱️ <b>ZeroGPU时间限制</b>:Hugging Face ZeroGPU 的推理时间限制为 180 秒。您可能需要使用免费帐户登录以使用此演示。大采样步骤可能会导致超时(GPU 中止)。在这种情况下,请考虑使用专业帐户登录或在本地计算机上运行。
    </p>
</div>
    """
    )
    VAE_input = gr.State()
    input_context = gr.State()

    with gr.Column():
        with gr.Row():
            input_style = gr.Radio(["GrayImage(ScreenStyle)", "Sketch"], label="Input Style", value="GrayImage(ScreenStyle)")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Image to Colorize")
                resolution = gr.Radio(["640x640", "512x800", "800x512"], label="Select Resolution (Width x Height)", value="640x640")
                extract_button = gr.Button("Preprocess (Decolorize)")
            extracted_image = gr.Image(type="pil", label="Decolorized Result")
        with gr.Row():
            reference_images = gr.Files(label="Reference Images (Upload multiple)", file_count="multiple")
            with gr.Column():
                output_gallery = gr.Gallery(label="Colorization Results", type="pil")
                seed = gr.Slider(label="Random Seed", minimum=0, maximum=100000, value=0, step=1)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=100, value=10, step=1)
                colorize_button = gr.Button("Colorize")

    extract_button.click(
        extract_line_image,
        inputs=[input_image, input_style, resolution],
        outputs=[extracted_image, VAE_input, input_context]
    )
    colorize_button.click(
        colorize_image,
        inputs=[VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps],
        outputs=output_gallery
    )

    with gr.Column():
        gr.Markdown("### Quick Examples")
        gr.Examples(
            examples=examples,
            inputs=[input_image, reference_images, input_style, resolution, seed, num_inference_steps],
            label="Examples",
            examples_per_page=6,
        )

demo.launch(server_name="0.0.0.0", server_port=22348)
    	
assets/example_0/input.jpg
ADDED
assets/example_0/ref1.jpg
ADDED
assets/example_1/input.jpg
ADDED
assets/example_1/ref1.jpg
ADDED
assets/example_1/ref2.jpg
ADDED
assets/example_1/ref3.jpg
ADDED
assets/example_2/input.png
ADDED
assets/example_2/ref1.png
ADDED
assets/example_2/ref2.png
ADDED
assets/example_2/ref3.png
ADDED
assets/example_3/input.png
ADDED
assets/example_3/ref1.png
ADDED
assets/example_3/ref2.png
ADDED
assets/example_3/ref3.png
ADDED
assets/example_4/input.jpg
ADDED
assets/example_4/ref1.jpg
ADDED
assets/example_4/ref2.jpg
ADDED
assets/example_4/ref3.jpg
ADDED
assets/example_5/input.png
ADDED
assets/example_5/ref1.png
ADDED
assets/example_5/ref2.png
ADDED
assets/example_5/ref3.png
ADDED
assets/mask.png
ADDED
    	
diffusers/.github/ISSUE_TEMPLATE/bug-report.yml
ADDED
@@ -0,0 +1,110 @@
| 1 | 
         
            +
            name: "\U0001F41B Bug Report"
         
     | 
| 2 | 
         
            +
            description: Report a bug on Diffusers
         
     | 
| 3 | 
         
            +
            labels: [ "bug" ]
         
     | 
| 4 | 
         
            +
            body:
         
     | 
| 5 | 
         
            +
              - type: markdown
         
     | 
| 6 | 
         
            +
                attributes:
         
     | 
| 7 | 
         
            +
                  value: |
         
     | 
| 8 | 
         
            +
                    Thanks a lot for taking the time to file this issue 🤗.
         
     | 
| 9 | 
         
            +
                    Issues do not only help to improve the library, but also publicly document common problems, questions, workflows for the whole community!
         
     | 
| 10 | 
         
            +
                    Thus, issues are of the same importance as pull requests when contributing to this library ❤️.
         
     | 
| 11 | 
         
            +
                    In order to make your issue as **useful for the community as possible**, let's try to stick to some simple guidelines:
         
     | 
| 12 | 
         
            +
                    - 1. Please try to be as precise and concise as possible.
         
     | 
| 13 | 
         
            +
                         *Give your issue a fitting title. Assume that someone which very limited knowledge of Diffusers can understand your issue. Add links to the source code, documentation other issues, pull requests etc...*
         
     | 
| 14 | 
         
            +
                    - 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.
         
     | 
| 15 | 
         
            +
                         *The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
         
     | 
| 16 | 
         
            +
                    - 3. Add the **minimum** amount of code / context that is needed to understand, reproduce your issue.
         
     | 
| 17 | 
         
            +
                         *Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
         
     | 
| 18 | 
         
            +
                    - 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
         
     | 
| 19 | 
         
            +
              - type: markdown
         
     | 
| 20 | 
         
            +
                attributes:
         
     | 
| 21 | 
         
            +
                  value: |
         
     | 
| 22 | 
         
            +
                    For more in-detail information on how to write good issues you can have a look [here](https://huggingface.co/course/chapter8/5?fw=pt).
         
     | 
| 23 | 
         
            +
              - type: textarea
         
     | 
| 24 | 
         
            +
                id: bug-description
         
     | 
| 25 | 
         
            +
                attributes:
         
     | 
| 26 | 
         
            +
                  label: Describe the bug
         
     | 
| 27 | 
         
            +
                  description: A clear and concise description of what the bug is. If you intend to submit a pull request for this issue, tell us in the description. Thanks!
         
     | 
| 28 | 
         
            +
                  placeholder: Bug description
         
     | 
| 29 | 
         
            +
                validations:
         
     | 
| 30 | 
         
            +
                  required: true
         
     | 
| 31 | 
         
            +
              - type: textarea
         
     | 
| 32 | 
         
            +
                id: reproduction
         
     | 
| 33 | 
         
            +
                attributes:
         
     | 
| 34 | 
         
            +
                  label: Reproduction
         
     | 
| 35 | 
         
            +
                  description: Please provide a minimal reproducible code which we can copy/paste and reproduce the issue.
         
     | 
| 36 | 
         
            +
                  placeholder: Reproduction
         
     | 
| 37 | 
         
            +
                validations:
         
     | 
| 38 | 
         
            +
                  required: true
         
     | 
| 39 | 
         
            +
              - type: textarea
         
     | 
| 40 | 
         
            +
                id: logs
         
     | 
| 41 | 
         
            +
                attributes:
         
     | 
| 42 | 
         
            +
                  label: Logs
         
     | 
| 43 | 
         
            +
                  description: "Please include the Python logs if you can."
         
     | 
| 44 | 
         
            +
                  render: shell
         
     | 
| 45 | 
         
            +
              - type: textarea
         
     | 
| 46 | 
         
            +
                id: system-info
         
     | 
| 47 | 
         
            +
                attributes:
         
     | 
| 48 | 
         
            +
                  label: System Info
         
     | 
| 49 | 
         
            +
                  description: Please share your system info with us. You can run the command `diffusers-cli env` and copy-paste its output below.
         
     | 
| 50 | 
         
            +
                  placeholder: Diffusers version, platform, Python version, ...
         
     | 
| 51 | 
         
            +
                validations:
         
     | 
| 52 | 
         
            +
                  required: true
         
     | 
| 53 | 
         
            +
              - type: textarea
         
     | 
| 54 | 
         
            +
                id: who-can-help
         
     | 
| 55 | 
         
            +
                attributes:
         
     | 
| 56 | 
         
            +
                  label: Who can help?
         
     | 
| 57 | 
         
            +
                  description: |
         
     | 
| 58 | 
         
            +
                    Your issue will be replied to more quickly if you can figure out the right person to tag with @.
         
     | 
| 59 | 
         
            +
                    If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                    All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
         
     | 
| 62 | 
         
            +
                    a core maintainer will ping the right person.
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                    Please tag a maximum of 2 people.
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): @sayakpaul @DN6
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                    Questions on pipelines:
         
     | 
| 69 | 
         
            +
                    - Stable Diffusion @yiyixuxu @asomoza
         
     | 
| 70 | 
         
            +
                    - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
         
     | 
| 71 | 
         
            +
                    - Stable Diffusion 3: @yiyixuxu @sayakpaul @DN6 @asomoza
         
     | 
| 72 | 
         
            +
                    - Kandinsky @yiyixuxu
         
     | 
| 73 | 
         
            +
                    - ControlNet @sayakpaul @yiyixuxu @DN6
         
     | 
| 74 | 
         
            +
                    - T2I Adapter @sayakpaul @yiyixuxu @DN6
         
     | 
| 75 | 
         
            +
                    - IF @DN6
         
     | 
| 76 | 
         
            +
                    - Text-to-Video / Video-to-Video @DN6 @a-r-r-o-w
         
     | 
| 77 | 
         
            +
                    - Wuerstchen @DN6
         
     | 
| 78 | 
         
            +
                    - Other: @yiyixuxu @DN6
         
     | 
| 79 | 
         
            +
                    - Improving generation quality: @asomoza
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                    Questions on models:
         
     | 
| 82 | 
         
            +
                    - UNet @DN6 @yiyixuxu @sayakpaul
         
     | 
| 83 | 
         
            +
                    - VAE @sayakpaul @DN6 @yiyixuxu
         
     | 
| 84 | 
         
            +
                    - Transformers/Attention @DN6 @yiyixuxu @sayakpaul
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                    Questions on single file checkpoints: @DN6
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                    Questions on Schedulers: @yiyixuxu
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                    Questions on LoRA: @sayakpaul
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                    Questions on Textual Inversion: @sayakpaul
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                    Questions on Training:
         
     | 
| 95 | 
         
            +
                    - DreamBooth @sayakpaul
         
     | 
| 96 | 
         
            +
                    - Text-to-Image Fine-tuning @sayakpaul
         
     | 
| 97 | 
         
            +
                    - Textual Inversion @sayakpaul
         
     | 
| 98 | 
         
            +
                    - ControlNet @sayakpaul
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
                    Questions on Tests: @DN6 @sayakpaul @yiyixuxu
         
     | 
| 101 | 
         
            +
             
     | 
| 102 | 
         
            +
                    Questions on Documentation: @stevhliu
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                    Questions on JAX- and MPS-related things: @pcuenca
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                    Questions on audio pipelines: @sanchit-gandhi
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
                  placeholder: "@Username ..."
         
     | 
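Guideline 2 of the template above asks for a snippet that reproduces the bug when copy-pasted into a Python shell. A minimal sketch of what such a reproduction could look like, assuming a public Stable Diffusion checkpoint; the checkpoint ID, prompt, and device below are illustrative placeholders, not anything the template prescribes:

```python
# Hypothetical minimal reproduction for a Diffusers bug report.
# The checkpoint ID and prompt are placeholders; substitute whatever
# actually triggers the bug you are reporting.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")  # placeholder device

# One short call that demonstrates the problem - nothing more.
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("repro.png")
```

Alongside such a snippet, the form expects the full traceback in the Logs field and the output of `diffusers-cli env` in the System Info field.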
    	
diffusers/.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1,4 @@
+contact_links:
+  - name: Questions / Discussions
+    url: https://github.com/huggingface/diffusers/discussions
+    about: General usage questions and community discussions
    	
diffusers/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,20 @@
+---
+name: "\U0001F680 Feature Request"
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...].
+
+**Describe the solution you'd like.**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered.**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context.**
+Add any other context or screenshots about the feature request here.
    	
diffusers/.github/ISSUE_TEMPLATE/feedback.md
ADDED
@@ -0,0 +1,12 @@
+---
+name: "💬 Feedback about API Design"
+about: Give feedback about the current API design
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**What API design would you like to have changed or added to the library? Why?**
+
+**What use case would this enable or better enable? Can you give us a code example?**
    	
diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml
ADDED
@@ -0,0 +1,31 @@
+name: "\U0001F31F New Model/Pipeline/Scheduler Addition"
+description: Submit a proposal/request to implement a new diffusion model/pipeline/scheduler
+labels: [ "New model/pipeline/scheduler" ]
+
+body:
+  - type: textarea
+    id: description-request
+    validations:
+      required: true
+    attributes:
+      label: Model/Pipeline/Scheduler description
+      description: |
+        Put any and all important information relevant to the model/pipeline/scheduler
+
+  - type: checkboxes
+    id: information-tasks
+    attributes:
+      label: Open source status
+      description: |
+        Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `diffusers`.
+      options:
+        - label: "The model implementation is available."
+        - label: "The model weights are available (Only relevant if addition is not a scheduler)."
+
+  - type: textarea
+    id: additional-info
+    attributes:
+      label: Provide useful links for the implementation
+      description: |
+        Please provide information regarding the implementation, the weights, and the authors.
+        Please mention the authors by @gh-username if you're aware of their usernames.
    	
diffusers/.github/ISSUE_TEMPLATE/translate.md
ADDED
@@ -0,0 +1,29 @@
+---
+name: 🌐 Translating a New Language?
+about: Start a new translation effort in your language
+title: '[<languageCode>] Translating docs to <languageName>'
+labels: WIP
+assignees: ''
+
+---
+
+<!--
+Note: Please search to see if an issue already exists for the language you are trying to translate.
+-->
+
+Hi!
+
+Let's bring the documentation to the entire <languageName>-speaking community 🌐.
+
+Would you like to help translate? Please follow the 🤗 [TRANSLATING guide](https://github.com/huggingface/diffusers/blob/main/docs/TRANSLATING.md). Here is a list of the files ready for translation. Let us know in this issue if you'd like to translate any, and we'll add your name to the list.
+
+Some notes:
+
+* Please translate using an informal tone (imagine you are talking with a friend about Diffusers 🤗).
+* Please translate in a gender-neutral way.
+* Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/diffusers/tree/main/docs/source).
+* Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml).
+* Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.
+* 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63).
+
+Thank you so much for your help! 🤗