katielink commited on
Commit
8dff69d
1 Parent(s): 5204736

Initial version

Browse files
README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - monai
4
+ - medical
5
+ library_name: monai
6
+ license: unknown
7
+ ---
8
+
9
+ # MedNIST GAN Hand Model
10
+
11
+ This model is a generator for creating images like the Hand category in the MedNIST dataset. It was trained as a GAN and accepts random values as inputs to produce an image output. The `train.json` file describes the training process along with the definition of the discriminator network used, and is based on the [MONAI GAN tutorials](https://github.com/Project-MONAI/tutorials/blob/main/modules/mednist_GAN_workflow_dict.ipynb).
12
+
13
+ This is a demonstration network meant to just show the training process for this sort of network with MONAI, its outputs are not particularly good and are of the same tiny size as the images in MedNIST. The training process was very short so a network with a longer training time would produce better results.
14
+
15
+ ### Downloading the Dataset
16
+
17
+ Download the dataset from [here](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz) and extract the contents to a convenient location.
18
+
19
+ The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),
20
+ [the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),
21
+ and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).
22
+
23
+ The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)
24
+ under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).
25
+
26
+
27
+ If you use the MedNIST dataset, please acknowledge the source.
28
+
29
+ ### Training
30
+
31
+ Assuming the current directory is the bundle directory, and the dataset was extracted to the directory `./MedNIST`, the following command will train the network for 50 epochs:
32
+
33
+ ```
34
+ PYTHONPATH=./scripts python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf --bundle_root .
35
+ ```
36
+
37
+ Note that the training code relies on extra scripts in the `scripts` directory which are made accessible by changing the `PYTHONPATH` variable in this invocation. If your `PYTHONPATH` is already used for other things you will have to add `./scripts` to the variable rather than replace it.
38
+
39
+ Not also the output from the training will be placed in the `models` directory but will not overwrite the `model.pt` file that may be there already. You will have to manually rename the most recent checkpoint file to `model.pt` to use the inference script mentioned below after checking the results are correct. This saved checkpoint contains a dictionary with the generator weights stored as `model` and omits the discriminator.
40
+
41
+ Another feature in the training file is the addition of sigmoid activation to the network by modifying it's structure at runtime. This is done with a line in the `training` section calling `add_module` on a layer of the network. This works best for training although the definition of the model now doesn't strictly match what it is in the `generator` section.
42
+
43
+ The generator and discriminator networks were both trained with the `Adam` optimizer with a learning rate of 0.0002 and `betas` values `[0.5, 0.999]`. These have been emperically found to be good values for the optimizer and this GAN problem.
44
+
45
+ ### Inference
46
+
47
+ The included `inference.json` generates a set number of png samples from the network and saves these to the directory `./outputs`. The output directory can be changed by setting the `output_dir` value, and the number of samples changed by setting the `num_samples` value. The following command line assumes it is invoked in the bundle directory:
48
+
49
+ ```
50
+ python -m monai.bundle run inferring --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf --bundle_root .
51
+ ```
52
+
53
+ Note this script uses postprocessing to apply the sigmoid activation the model's outputs and to save the results to image files.
54
+
55
+
56
+ ### Export
57
+
58
+ The generator can be exported to a Torchscript bundle with the following:
59
+
60
+ ```
61
+ python -m monai.bundle ckpt_export network_def --filepath mednist_gan.ts --ckpt_file models/model.pt --meta_file configs/metadata.json --config_file configs/inference.json
62
+ ```
63
+
64
+ The model can be loaded without MONAI code after this operation. For example, an image can be generated from a set of random values with:
65
+
66
+ ```python
67
+ import torch
68
+ net = torch.jit.load("mednist_gan.ts")
69
+ latent = torch.rand(1,64)
70
+ img = net(latent) # (1,1,64,64)
71
+ ```
configs/inference.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import glob"
4
+ ],
5
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
6
+ "ckpt_path": "$@bundle_root + '/models/model.pt'",
7
+ "output_dir": "./outputs",
8
+ "latent_size": 64,
9
+ "num_samples": 10,
10
+ "network_def": {
11
+ "_target_": "Generator",
12
+ "latent_shape": "@latent_size",
13
+ "start_shape": [
14
+ 64,
15
+ 8,
16
+ 8
17
+ ],
18
+ "channels": [
19
+ 32,
20
+ 16,
21
+ 8,
22
+ 1
23
+ ],
24
+ "strides": [
25
+ 2,
26
+ 2,
27
+ 2,
28
+ 1
29
+ ]
30
+ },
31
+ "network": "$@network_def.to(@device)",
32
+ "dataset": {
33
+ "_target_": "Dataset",
34
+ "data": "$[torch.rand(@latent_size) for i in range(@num_samples)]"
35
+ },
36
+ "dataloader": {
37
+ "_target_": "DataLoader",
38
+ "dataset": "@dataset",
39
+ "batch_size": 1,
40
+ "shuffle": false,
41
+ "num_workers": 0
42
+ },
43
+ "inferer": {
44
+ "_target_": "SimpleInferer"
45
+ },
46
+ "postprocessing": {
47
+ "_target_": "Compose",
48
+ "transforms": [
49
+ {
50
+ "_target_": "Activationsd",
51
+ "keys": "pred",
52
+ "sigmoid": true
53
+ },
54
+ {
55
+ "_target_": "SaveImaged",
56
+ "keys": "pred",
57
+ "output_dir": "@output_dir",
58
+ "output_ext": "png",
59
+ "separate_folder": false,
60
+ "scale": 255,
61
+ "output_dtype": "$np.uint8",
62
+ "meta_key_postfix": null
63
+ }
64
+ ]
65
+ },
66
+ "handlers": [
67
+ {
68
+ "_target_": "CheckpointLoader",
69
+ "load_path": "@ckpt_path",
70
+ "load_dict": {
71
+ "model": "@network"
72
+ }
73
+ }
74
+ ],
75
+ "evaluator": {
76
+ "_target_": "SupervisedEvaluator",
77
+ "device": "@device",
78
+ "val_data_loader": "@dataloader",
79
+ "network": "@network",
80
+ "inferer": "@inferer",
81
+ "postprocessing": "@postprocessing",
82
+ "prepare_batch": "$lambda batchdata, *_,**__: (batchdata.to(@device),None,(),{})",
83
+ "val_handlers": "@handlers"
84
+ },
85
+ "inferring": [
86
+ "$@evaluator.run()"
87
+ ]
88
+ }
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
configs/metadata.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_20220718.json",
3
+ "version": "0.1.0",
4
+ "changelog": {
5
+ "0.1.0": "Initial version"
6
+ },
7
+ "monai_version": "0.9.0",
8
+ "pytorch_version": "1.10.0",
9
+ "numpy_version": "1.21.0",
10
+ "optional_packages_version": {
11
+ "pytorch-ignite": "0.4.8",
12
+ "pillow": "8.4.0"
13
+ },
14
+ "task": "Generate random hand images from the MedNIST dataset",
15
+ "description": "This example of a GAN generator produces hand xray images like those in the MedNIST dataset",
16
+ "authors": "MONAI Team",
17
+ "copyright": "Copyright (c) MONAI Consortium",
18
+ "intended_use": "This is an example of a GAN with generator discriminator networks using MONAI, suitable for demonstration purposes only.",
19
+ "data_source": "MedNIST dataset kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic)",
20
+ "data_type": "jpeg",
21
+ "network_data_format": {
22
+ "inputs": {
23
+ "latent": {
24
+ "type": "tuples",
25
+ "format": "latent",
26
+ "num_channels": 0,
27
+ "spatial_shape": [
28
+ 64
29
+ ],
30
+ "dtype": "float32",
31
+ "value_range": [
32
+ 0,
33
+ 1
34
+ ],
35
+ "is_patch_data": false,
36
+ "channel_def": {}
37
+ }
38
+ },
39
+ "outputs": {
40
+ "pred": {
41
+ "type": "image",
42
+ "format": "magnitude",
43
+ "num_channels": 1,
44
+ "spatial_shape": [
45
+ 64,
46
+ 64
47
+ ],
48
+ "dtype": "float32",
49
+ "value_range": [
50
+ 0,
51
+ 1
52
+ ],
53
+ "is_patch_data": false,
54
+ "channel_def": {
55
+ "0": "image"
56
+ }
57
+ }
58
+ }
59
+ }
60
+ }
configs/train.json ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$from functools import partial",
4
+ "$import glob",
5
+ "$from losses import discriminator_loss",
6
+ "$from losses import generator_loss"
7
+ ],
8
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
9
+ "ckpt_path": "$@bundle_root + '/models/model.pt'",
10
+ "dataset_dir": "./MedNIST/Hand",
11
+ "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))",
12
+ "latent_size": 64,
13
+ "discriminator": {
14
+ "_target_": "Discriminator",
15
+ "in_shape": [
16
+ 1,
17
+ 64,
18
+ 64
19
+ ],
20
+ "channels": [
21
+ 8,
22
+ 16,
23
+ 32,
24
+ 64,
25
+ 1
26
+ ],
27
+ "strides": [
28
+ 2,
29
+ 2,
30
+ 2,
31
+ 2,
32
+ 1
33
+ ],
34
+ "num_res_units": 1,
35
+ "kernel_size": 5
36
+ },
37
+ "generator": {
38
+ "_target_": "Generator",
39
+ "latent_shape": "@latent_size",
40
+ "start_shape": [
41
+ 64,
42
+ 8,
43
+ 8
44
+ ],
45
+ "channels": [
46
+ 32,
47
+ 16,
48
+ 8,
49
+ 1
50
+ ],
51
+ "strides": [
52
+ 2,
53
+ 2,
54
+ 2,
55
+ 1
56
+ ]
57
+ },
58
+ "dnetwork": "$@discriminator.apply(monai.networks.normal_init).to(@device)",
59
+ "gnetwork": "$@generator.apply(monai.networks.normal_init).to(@device)",
60
+ "preprocessing": {
61
+ "_target_": "Compose",
62
+ "transforms": [
63
+ {
64
+ "_target_": "LoadImaged",
65
+ "keys": "reals"
66
+ },
67
+ {
68
+ "_target_": "AddChanneld",
69
+ "keys": "reals"
70
+ },
71
+ {
72
+ "_target_": "ScaleIntensityd",
73
+ "keys": "reals"
74
+ },
75
+ {
76
+ "_target_": "RandRotated",
77
+ "keys": "reals",
78
+ "range_x": "$np.pi/12",
79
+ "prob": 0.5,
80
+ "keep_size": true
81
+ },
82
+ {
83
+ "_target_": "RandFlipd",
84
+ "keys": "reals",
85
+ "spatial_axis": 0,
86
+ "prob": 0.5
87
+ },
88
+ {
89
+ "_target_": "RandZoomd",
90
+ "keys": "reals",
91
+ "min_zoom": 0.9,
92
+ "max_zoom": 1.1,
93
+ "prob": 0.5
94
+ },
95
+ {
96
+ "_target_": "EnsureTyped",
97
+ "keys": "reals"
98
+ }
99
+ ]
100
+ },
101
+ "real_dataset": {
102
+ "_target_": "CacheDataset",
103
+ "data": "$[{'reals': i} for i in @datalist]",
104
+ "transform": "@preprocessing"
105
+ },
106
+ "real_dataloader": {
107
+ "_target_": "DataLoader",
108
+ "dataset": "@real_dataset",
109
+ "batch_size": 600,
110
+ "shuffle": true,
111
+ "num_workers": 12
112
+ },
113
+ "doptimizer": {
114
+ "_target_": "torch.optim.Adam",
115
+ "params": "$@dnetwork.parameters()",
116
+ "lr": 0.0002,
117
+ "betas": [
118
+ 0.5,
119
+ 0.999
120
+ ]
121
+ },
122
+ "goptimizer": {
123
+ "_target_": "torch.optim.Adam",
124
+ "params": "$@gnetwork.parameters()",
125
+ "lr": 0.0002,
126
+ "betas": [
127
+ 0.5,
128
+ 0.999
129
+ ]
130
+ },
131
+ "handlers": [
132
+ {
133
+ "_target_": "CheckpointSaver",
134
+ "save_dir": "$@bundle_root + '/models'",
135
+ "save_dict": {
136
+ "model": "@gnetwork"
137
+ },
138
+ "save_interval": 0,
139
+ "save_final": true,
140
+ "epoch_level": true
141
+ }
142
+ ],
143
+ "trainer": {
144
+ "_target_": "GanTrainer",
145
+ "device": "@device",
146
+ "max_epochs": 50,
147
+ "train_data_loader": "@real_dataloader",
148
+ "g_network": "@gnetwork",
149
+ "g_optimizer": "@goptimizer",
150
+ "g_loss_function": "$partial(generator_loss, disc_net=@dnetwork)",
151
+ "d_network": "@dnetwork",
152
+ "d_optimizer": "@doptimizer",
153
+ "d_loss_function": "$partial(discriminator_loss, disc_net=@dnetwork)",
154
+ "d_train_steps": 5,
155
+ "g_update_latents": true,
156
+ "latent_shape": "@latent_size",
157
+ "key_train_metric": "$None",
158
+ "train_handlers": "@handlers"
159
+ },
160
+ "training": [
161
+ "$@gnetwork.conv.add_module('activation', torch.nn.Sigmoid())",
162
+ "$@trainer.run()"
163
+ ]
164
+ }
docs/README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # MedNIST GAN Hand Model
3
+
4
+ This model is a generator for creating images like the Hand category in the MedNIST dataset. It was trained as a GAN and accepts random values as inputs to produce an image output. The `train.json` file describes the training process along with the definition of the discriminator network used, and is based on the [MONAI GAN tutorials](https://github.com/Project-MONAI/tutorials/blob/main/modules/mednist_GAN_workflow_dict.ipynb).
5
+
6
+ This is a demonstration network meant to just show the training process for this sort of network with MONAI, its outputs are not particularly good and are of the same tiny size as the images in MedNIST. The training process was very short so a network with a longer training time would produce better results.
7
+
8
+ ### Downloading the Dataset
9
+
10
+ Download the dataset from [here](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz) and extract the contents to a convenient location.
11
+
12
+ The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),
13
+ [the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),
14
+ and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).
15
+
16
+ The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)
17
+ under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).
18
+
19
+
20
+ If you use the MedNIST dataset, please acknowledge the source.
21
+
22
+ ### Training
23
+
24
+ Assuming the current directory is the bundle directory, and the dataset was extracted to the directory `./MedNIST`, the following command will train the network for 50 epochs:
25
+
26
+ ```
27
+ PYTHONPATH=./scripts python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf --bundle_root .
28
+ ```
29
+
30
+ Note that the training code relies on extra scripts in the `scripts` directory which are made accessible by changing the `PYTHONPATH` variable in this invocation. If your `PYTHONPATH` is already used for other things you will have to add `./scripts` to the variable rather than replace it.
31
+
32
+ Not also the output from the training will be placed in the `models` directory but will not overwrite the `model.pt` file that may be there already. You will have to manually rename the most recent checkpoint file to `model.pt` to use the inference script mentioned below after checking the results are correct. This saved checkpoint contains a dictionary with the generator weights stored as `model` and omits the discriminator.
33
+
34
+ Another feature in the training file is the addition of sigmoid activation to the network by modifying it's structure at runtime. This is done with a line in the `training` section calling `add_module` on a layer of the network. This works best for training although the definition of the model now doesn't strictly match what it is in the `generator` section.
35
+
36
+ The generator and discriminator networks were both trained with the `Adam` optimizer with a learning rate of 0.0002 and `betas` values `[0.5, 0.999]`. These have been emperically found to be good values for the optimizer and this GAN problem.
37
+
38
+ ### Inference
39
+
40
+ The included `inference.json` generates a set number of png samples from the network and saves these to the directory `./outputs`. The output directory can be changed by setting the `output_dir` value, and the number of samples changed by setting the `num_samples` value. The following command line assumes it is invoked in the bundle directory:
41
+
42
+ ```
43
+ python -m monai.bundle run inferring --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf --bundle_root .
44
+ ```
45
+
46
+ Note this script uses postprocessing to apply the sigmoid activation the model's outputs and to save the results to image files.
47
+
48
+
49
+ ### Export
50
+
51
+ The generator can be exported to a Torchscript bundle with the following:
52
+
53
+ ```
54
+ python -m monai.bundle ckpt_export network_def --filepath mednist_gan.ts --ckpt_file models/model.pt --meta_file configs/metadata.json --config_file configs/inference.json
55
+ ```
56
+
57
+ The model can be loaded without MONAI code after this operation. For example, an image can be generated from a set of random values with:
58
+
59
+ ```python
60
+ import torch
61
+ net = torch.jit.load("mednist_gan.ts")
62
+ latent = torch.rand(1,64)
63
+ img = net(latent) # (1,1,64,64)
64
+ ```
docs/license.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2022 MONAI Consortium
2
+ Licensed under the Apache License, Version 2.0 (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+ http://www.apache.org/licenses/LICENSE-2.0
6
+ Unless required by applicable law or agreed to in writing, software
7
+ distributed under the License is distributed on an "AS IS" BASIS,
8
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ See the License for the specific language governing permissions and
10
+ limitations under the License.
11
+
12
+ Third Party Licenses
13
+ -----------------------------------------------------------------------
14
+
15
+ /*********************************************************************/
16
+ i. MedNIST Dataset
17
+ The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (https://www.mayo.edu/research/labs/radiology-informatics/overview), Department of Radiology, Mayo Clinic under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).
18
+
19
+ The MedNIST dataset was gathered from several sets from:
20
+ * TCIA (https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions)
21
+ * the RSNA Bone Age Challenge (http://rsnachallenges.cloudapp.net/competitions/4),
22
+ * the NIH Chest X-ray dataset (https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).
23
+
24
+ If you use the MedNIST dataset, please acknowledge the source. For the license and usage conditions of the source datasets, please see their respective sites.
models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ee4b4a812063936e95a5e39e628884b0160ab5be4897d99ede3dc7746c4308a
3
+ size 1272514
scripts/losses.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ disc_loss_criterion = torch.nn.BCELoss()
4
+ gen_loss_criterion = torch.nn.BCELoss()
5
+ real_label = 1
6
+ fake_label = 0
7
+
8
+
9
+ def discriminator_loss(gen_images, real_images, disc_net):
10
+ real = real_images.new_full((real_images.shape[0], 1), real_label)
11
+ gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
12
+
13
+ realloss = disc_loss_criterion(disc_net(real_images), real)
14
+ genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
15
+
16
+ return (genloss + realloss) / 2
17
+
18
+
19
+ def generator_loss(gen_images, disc_net):
20
+ output = disc_net(gen_images)
21
+ cats = output.new_full(output.shape, real_label)
22
+ return gen_loss_criterion(output, cats)