ndhieunguyen committed
Commit 7dd9869
Parent: 925f3c0

Add application file
Files changed (46)
  1. .gitignore +160 -0
  2. README.md +1 -1
  3. app.py +110 -0
  4. checkpoints/PLAIN_ema_0.9999_360000.pt +3 -0
  5. dataset/selfies_dict.txt +2944 -0
  6. environment.yaml +129 -0
  7. inference.py +202 -0
  8. inference_submission.py +189 -0
  9. requirements.txt +0 -0
  10. src/__init__.py +0 -0
  11. src/anlg_infill/anlg.py +130 -0
  12. src/anlg_infill/mbr_eval.py +351 -0
  13. src/anlg_infill/post_process.py +35 -0
  14. src/anlg_infill/run_evaluation.py +81 -0
  15. src/control_gen/baseline_control.py +500 -0
  16. src/control_gen/eval_control.py +567 -0
  17. src/ev.py +117 -0
  18. src/evaluation/fcd_metric.py +54 -0
  19. src/evaluation/fingerprint_metrics.py +81 -0
  20. src/evaluation/mol_translation_metrics.py +129 -0
  21. src/improved_diffusion/__init__.py +0 -0
  22. src/improved_diffusion/dist_util.py +87 -0
  23. src/improved_diffusion/fp16_util.py +76 -0
  24. src/improved_diffusion/gaussian_diffusion.py +1606 -0
  25. src/improved_diffusion/image_datasets.py +120 -0
  26. src/improved_diffusion/logger.py +498 -0
  27. src/improved_diffusion/losses.py +119 -0
  28. src/improved_diffusion/nn.py +170 -0
  29. src/improved_diffusion/resample.py +154 -0
  30. src/improved_diffusion/respace.py +131 -0
  31. src/improved_diffusion/rounding.py +119 -0
  32. src/improved_diffusion/script_util.py +201 -0
  33. src/improved_diffusion/test_util.py +108 -0
  34. src/improved_diffusion/text_datasets.py +948 -0
  35. src/improved_diffusion/train_util.py +445 -0
  36. src/improved_diffusion/transformer_model.py +118 -0
  37. src/improved_diffusion/transformer_utils.py +450 -0
  38. src/scripts/__init__.py +0 -0
  39. src/scripts/batch_decode.py +149 -0
  40. src/scripts/batch_nll.py +29 -0
  41. src/scripts/infill_util.py +355 -0
  42. src/scripts/mydatasets.py +326 -0
  43. src/scripts/mytokenizers.py +249 -0
  44. src/scripts/nll.py +241 -0
  45. src/scripts/tree_helper.py +110 -0
  46. train.py +177 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Lang2mol Diff
- emoji: 🏆
+ emoji: 🧬
  colorFrom: pink
  colorTo: pink
  sdk: streamlit
app.py ADDED
@@ -0,0 +1,110 @@
+ import torch
+ import argparse
+ import selfies as sf
+ from tqdm import tqdm
+ from transformers import T5EncoderModel
+ from transformers import set_seed
+ from src.scripts.mytokenizers import Tokenizer
+ from src.improved_diffusion import gaussian_diffusion as gd
+ from src.improved_diffusion import dist_util, logger
+ from src.improved_diffusion.respace import SpacedDiffusion
+ from src.improved_diffusion.transformer_model import TransformerNetModel
+ from src.improved_diffusion.script_util import (
+     model_and_diffusion_defaults,
+     add_dict_to_argparser,
+ )
+ from src.scripts.mydatasets import Lang2molDataset_submission
+ import streamlit as st
+ import os
+
+
+ @st.cache_resource
+ def get_encoder():
+     model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
+     model.eval()
+     return model
+
+
+ @st.cache_resource
+ def get_tokenizer():
+     return Tokenizer()
+
+
+ @st.cache_resource
+ def get_model():
+     model = TransformerNetModel(
+         in_channels=32,
+         model_channels=128,
+         dropout=0.1,
+         vocab_size=35073,
+         hidden_size=1024,
+         num_attention_heads=16,
+         num_hidden_layers=12,
+     )
+     model.load_state_dict(
+         dist_util.load_state_dict(
+             os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
+             map_location="cpu",
+         )
+     )
+     model.eval()
+     return model
+
+
+ @st.cache_resource
+ def get_diffusion():
+     return SpacedDiffusion(
+         use_timesteps=[i for i in range(0, 2000, 10)],
+         betas=gd.get_named_beta_schedule("sqrt", 2000),
+         model_mean_type=gd.ModelMeanType.START_X,
+         model_var_type=gd.ModelVarType.FIXED_LARGE,
+         loss_type=gd.LossType.E2E_MSE,
+         rescale_timesteps=True,
+         model_arch="transformer",
+         training_mode="e2e",
+     )
+
+
+ tokenizer = get_tokenizer()
+ encoder = get_encoder()
+ model = get_model()
+ diffusion = get_diffusion()
+
+ sample_fn = diffusion.ddim_sample_loop
+
+ text_input = st.text_area("Enter molecule description")
+ output = tokenizer(
+     text_input,
+     max_length=256,
+     truncation=True,
+     padding="max_length",
+     add_special_tokens=True,
+     return_tensors="pt",
+     return_attention_mask=True,
+ )
+ caption_state = encoder(
+     input_ids=output["input_ids"],
+     attention_mask=output["attention_mask"],
+ ).last_hidden_state
+ caption_mask = output["attention_mask"]
+
+ outputs = sample_fn(
+     model,
+     (1, 256, 32),
+     clip_denoised=False,
+     denoised_fn=None,
+     model_kwargs={},
+     top_p=1.0,
+     progress=True,
+     caption=(caption_state, caption_mask),
+ )
+ logits = model.get_logits(torch.tensor(outputs))
+ cands = torch.topk(logits, k=1, dim=-1)
+ outputs = cands.indices
+ outputs = outputs.squeeze(-1)
+ outputs = tokenizer.decode(outputs)
+ result = sf.decoder(
+     outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
+ ).replace("\t", "")
+
+ st.write(result)
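Two details of app.py worth noting: the SpacedDiffusion keeps every 10th of the 2,000 training timesteps, so the DDIM loop runs 200 denoising steps over a (1, 256, 32) latent; and the decoded token string still carries padding and EOS markers, which are stripped before SELFIES decoding. Below is a minimal standalone sketch of that final post-processing step, using only the selfies package; the raw string is a hypothetical stand-in for model output, not an actual sample.

import selfies as sf

# Hypothetical decoder output: SELFIES symbols followed by padding/EOS artifacts.
raw = "[C][C][O]<pad><pad></s>"
cleaned = raw.replace("<pad>", "").replace("</s>", "").replace("\t", "")
smiles = sf.decoder(cleaned)  # SELFIES -> SMILES
print(smiles)  # CCO (ethanol)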
checkpoints/PLAIN_ema_0.9999_360000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d77c45acf5644b5e42e68000b1b2f94a25c1f3b4eb1dde26fdfcca3d7482f11b
+ size 1021819692
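Only this Git LFS pointer lives in the repository; the ~1 GB weights themselves arrive with `git lfs pull`. A minimal sketch, assuming the file has been pulled, that checks the download against the pointer's SHA-256 while streaming so the whole checkpoint never sits in memory:

import hashlib

def sha256_of(path: str, chunk: int = 1 << 20) -> str:
    # Hash in 1 MiB chunks so a ~1 GB file uses constant memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk):
            h.update(block)
    return h.hexdigest()

expected = "d77c45acf5644b5e42e68000b1b2f94a25c1f3b4eb1dde26fdfcca3d7482f11b"
assert sha256_of("checkpoints/PLAIN_ema_0.9999_360000.pt") == expected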
dataset/selfies_dict.txt ADDED
@@ -0,0 +1,2944 @@
1
+ [U-5]
2
+ [V]
3
+ [40Ca]
4
+ [SbH3]
5
+ [232Np]
6
+ [127Sn]
7
+ [SnH2+2]
8
+ [195Pt+2]
9
+ [21NH3]
10
+ [SiH1+1]
11
+ [ClH0]
12
+ [175Yb+3]
13
+ [184Ta]
14
+ [Pt+1]
15
+ [81Sr]
16
+ [=32P]
17
+ [116Sn]
18
+ [C@@]
19
+ [ClH3+2]
20
+ [99Tc+5]
21
+ [=Mo+4]
22
+ [238Th]
23
+ [141Pr]
24
+ [SiH4]
25
+ [/SiH2]
26
+ [=Branch3]
27
+ [PoH2]
28
+ [52Fe]
29
+ [66Cu]
30
+ [226Rn]
31
+ [138Xe]
32
+ [PH4+1]
33
+ [Zn+1]
34
+ [V+1]
35
+ [253Fm]
36
+ [121IH1]
37
+ [199Po]
38
+ [62Cu+2]
39
+ [12BH2]
40
+ [I+3]
41
+ [Te]
42
+ [208Bi]
43
+ [O-1]
44
+ [Cu-3]
45
+ [#Branch3]
46
+ [198Au]
47
+ [224Ra]
48
+ [156Ho]
49
+ [=Dy]
50
+ [CH2]
51
+ [\N]
52
+ [125Sn]
53
+ [220Ra]
54
+ [/13C@H1]
55
+ [Ta-1]
56
+ [/SH0]
57
+ [=WH4]
58
+ [#11C]
59
+ [65Cu]
60
+ [169Lu]
61
+ [=Si+2]
62
+ [72As]
63
+ [=U]
64
+ [/O+1]
65
+ [ClH1+1]
66
+ [98Tc+4]
67
+ [/Al-1]
68
+ [#Ce]
69
+ [GeH3]
70
+ [N@]
71
+ [107Cd]
72
+ [202Bi]
73
+ [CuH1+1]
74
+ [248Cm]
75
+ [\O]
76
+ [=TeH2]
77
+ [72Ge]
78
+ [#Yb]
79
+ [/Te]
80
+ [=Al]
81
+ [#121Sb]
82
+ [246Pu]
83
+ [18OH2]
84
+ [=Si-2]
85
+ [\N@+1]
86
+ [Ni+2]
87
+ [Nb-1]
88
+ [171Tm]
89
+ [Co-2]
90
+ [71Zn]
91
+ [/Hg]
92
+ [PtH2]
93
+ [86Y+3]
94
+ [18O-2]
95
+ [Ta+2]
96
+ [IH1]
97
+ [153Tb]
98
+ [169Er+3]
99
+ [211Bi]
100
+ [=11C]
101
+ [Li-1]
102
+ [107Rh]
103
+ [=Cu]
104
+ [126Xe]
105
+ [88Rb]
106
+ [Ge]
107
+ [123I]
108
+ [\NH1+1]
109
+ [Rh-3]
110
+ [184W]
111
+ [\CH1]
112
+ [9C-1]
113
+ [110Cd]
114
+ [=AlH1]
115
+ [MnH2]
116
+ [Ge@H1]
117
+ [108Ag]
118
+ [141Pm]
119
+ [In+3]
120
+ [13NH2-1]
121
+ [#Cr]
122
+ [=P@]
123
+ [8BH2]
124
+ [94Zr+4]
125
+ [130Ba]
126
+ [13NH4+1]
127
+ [=V]
128
+ [12C-1]
129
+ [Mg+2]
130
+ [#SH1-1]
131
+ [19F]
132
+ [89Zr+3]
133
+ [232Th]
134
+ [112Cd]
135
+ [In]
136
+ [RhH3]
137
+ [91Y]
138
+ [I]
139
+ [184Re]
140
+ [92Sr]
141
+ [BiH3]
142
+ [\P+1]
143
+ [/P-1]
144
+ [151Sm]
145
+ [Au-3]
146
+ [69Ge]
147
+ [=TeH1]
148
+ [SmH3]
149
+ [183Re]
150
+ [ReH2]
151
+ [17F]
152
+ [122Te]
153
+ [195Pt]
154
+ [167Tm+3]
155
+ [2H-1]
156
+ [232Pa]
157
+ [113In+3]
158
+ [=95Tc+4]
159
+ [=InH1]
160
+ [Ag-1]
161
+ [NiH2+2]
162
+ [AuH3]
163
+ [70Zn+2]
164
+ [160Tb]
165
+ [/131I]
166
+ [14CH1]
167
+ [35P]
168
+ [Ni-2]
169
+ [=W]
170
+ [/NH1+1]
171
+ [13OH2]
172
+ [197Po]
173
+ [RuH2+2]
174
+ [39ClH1]
175
+ [FeH1]
176
+ [NH1]
177
+ [7Be]
178
+ [144Ce+4]
179
+ [Po@]
180
+ [33ClH1]
181
+ [\AlH1]
182
+ [18CH3]
183
+ [SnH1]
184
+ [45Ca]
185
+ [As]
186
+ [Sn@@]
187
+ [/BH1-1]
188
+ [107Pd]
189
+ [Tm]
190
+ [SiH2]
191
+ [ZrH3]
192
+ [20OH1]
193
+ [SH1+1]
194
+ [44Ti]
195
+ [AlH5-2]
196
+ [MoH1]
197
+ [149Pr]
198
+ [#Ta]
199
+ [176Ta]
200
+ [=20CH1]
201
+ [Ru+3]
202
+ [=W-1]
203
+ [14C@@]
204
+ [33PH3]
205
+ [16OH1]
206
+ [=GaH1]
207
+ [53Ni]
208
+ [35Cl-1]
209
+ [92Zr]
210
+ [83Kr]
211
+ [32Cl]
212
+ [TeH1]
213
+ [Ir-4]
214
+ [13N+1]
215
+ [19BH2]
216
+ [=18O]
217
+ [31PH1]
218
+ [#Dy]
219
+ [PH1]
220
+ [Se+4]
221
+ [146Nd]
222
+ [125Sb]
223
+ [XeH1]
224
+ [186Pt]
225
+ [BiH2]
226
+ [=Tc+4]
227
+ [44Sc]
228
+ [BiH2+2]
229
+ [CoH3]
230
+ [SiH1-1]
231
+ [\PH0]
232
+ [203Tl]
233
+ [=Ta]
234
+ [Ge-1]
235
+ [Y]
236
+ [68Ga]
237
+ [=CoH1]
238
+ [Cl+3]
239
+ [=16O]
240
+ [/As+1]
241
+ [103Ru+2]
242
+ [62Co]
243
+ [207Bi]
244
+ [191Po]
245
+ [\F]
246
+ [Rb]
247
+ [113Sn]
248
+ [Ti+2]
249
+ [Sm+3]
250
+ [#PH1+1]
251
+ [V+2]
252
+ [125Xe]
253
+ [SbH1+1]
254
+ [Tc+6]
255
+ [AsH1]
256
+ [-/Ring2]
257
+ [#16O+1]
258
+ [CuH1]
259
+ [Zr-2]
260
+ [#GeH1]
261
+ [58Ni]
262
+ [77Ge]
263
+ [Co+2]
264
+ [87Sr+2]
265
+ [\PH2+1]
266
+ [93Y]
267
+ [=Mg]
268
+ [172Ta]
269
+ [=CrH2]
270
+ [#Tb]
271
+ [\2H]
272
+ [139Cs]
273
+ [136Nd]
274
+ [Ca+1]
275
+ [#P]
276
+ [36SH2]
277
+ [49Ca]
278
+ [19CH3]
279
+ [CH1-1]
280
+ [80Br-1]
281
+ [49Ti]
282
+ [88Y]
283
+ [TlH2]
284
+ [FeH4]
285
+ [226Ra]
286
+ [BH4-1]
287
+ [=14C-1]
288
+ [13CH2+1]
289
+ [Ge@]
290
+ [=Zr]
291
+ [47Ti]
292
+ [111IH1]
293
+ [\SH2+1]
294
+ [/9C]
295
+ [58Co]
296
+ [=NH2+1]
297
+ [206Pb]
298
+ [12CH1]
299
+ [93Mo]
300
+ [34S-2]
301
+ [77Kr]
302
+ [/Si-1]
303
+ [=32S]
304
+ [240Cm]
305
+ [249Bk]
306
+ [20CH2]
307
+ [128Sb]
308
+ [Zn-2]
309
+ [In+1]
310
+ [203Pb]
311
+ [18CH1]
312
+ [GaH1]
313
+ [\NH1-1]
314
+ [124Sn]
315
+ [Re]
316
+ [/NH1]
317
+ [/C-1]
318
+ [94Tc]
319
+ [118Sb]
320
+ [186Os]
321
+ [Co]
322
+ [47Ca+2]
323
+ [=SbH2]
324
+ [Branch3]
325
+ [30Si]
326
+ [Ring1]
327
+ [/Tl]
328
+ [S-1]
329
+ [96Mo]
330
+ [15N]
331
+ [SiH3-1]
332
+ [PH3+1]
333
+ [143Nd]
334
+ [=SbH3]
335
+ [\Ge]
336
+ [36Ar]
337
+ [=Th]
338
+ [=Pb]
339
+ [=Tc+3]
340
+ [/13CH1-1]
341
+ [AlH1]
342
+ [141Ba]
343
+ [177Ta]
344
+ [BrH1+1]
345
+ [=19O]
346
+ [156Gd]
347
+ [N@@H1+1]
348
+ [16OH2]
349
+ [N-1]
350
+ [254Fm]
351
+ [186Lu]
352
+ [18C]
353
+ [246Am]
354
+ [#Th]
355
+ [194Po]
356
+ [#Mo+1]
357
+ [=34S]
358
+ [110Ru]
359
+ [92Mo]
360
+ [169Yb+3]
361
+ [89Y+3]
362
+ [15NH2]
363
+ [173Yb]
364
+ [185Ir]
365
+ [3H+1]
366
+ [/79Br]
367
+ [IH0]
368
+ [121I]
369
+ [\15NH1]
370
+ [=Gd]
371
+ [=SnH1]
372
+ [151Nd]
373
+ [Os+7]
374
+ [74Kr]
375
+ [Bi-1]
376
+ [78Kr]
377
+ [119Sb]
378
+ [9CH4]
379
+ [=Ring1]
380
+ [\SiH2]
381
+ [#Nd]
382
+ [19Ne]
383
+ [#Ti]
384
+ [=CH0]
385
+ [95Tc]
386
+ [138Ba]
387
+ [16NH2]
388
+ [31P]
389
+ [120Xe]
390
+ [Se@@]
391
+ [15NH2-1]
392
+ [Pt+4]
393
+ [13NH3]
394
+ [85Sr+2]
395
+ [197Hg+2]
396
+ [14C@]
397
+ [Tl-3]
398
+ [233U]
399
+ [146Pm]
400
+ [221Fr]
401
+ [/Hg+1]
402
+ [N@H1+1]
403
+ [#12CH1]
404
+ [AlH3-1]
405
+ [/Ge]
406
+ [181Ta]
407
+ [#Y]
408
+ [143Ce]
409
+ [33S]
410
+ [=La]
411
+ [#In]
412
+ [Cu+1]
413
+ [Nb+3]
414
+ [65Cu+1]
415
+ [Zn+2]
416
+ [\OH1+1]
417
+ [=SH0]
418
+ [10Be]
419
+ [74As]
420
+ [164Er]
421
+ [Sn+2]
422
+ [188W]
423
+ [157Tb]
424
+ [84BrH1]
425
+ [71Se]
426
+ [/S]
427
+ [55Fe+3]
428
+ [208Tl]
429
+ [199Pt]
430
+ [WH6]
431
+ [151Pm]
432
+ [AlH3-3]
433
+ [65Zn]
434
+ [=Ag]
435
+ [77As]
436
+ [Co+3]
437
+ [132IH1]
438
+ [Rh-1]
439
+ [15NH1]
440
+ [PoH1]
441
+ [100Rh]
442
+ [8He]
443
+ [168Yb]
444
+ [#Ge]
445
+ [29Si]
446
+ [27Mg]
447
+ [205Bi+3]
448
+ [109Ag]
449
+ [13CH3-1]
450
+ [237Np]
451
+ [=Cd]
452
+ [35ClH1]
453
+ [137Cs]
454
+ [/Se-1]
455
+ [64Cu]
456
+ [AlH1-1]
457
+ [172Hf]
458
+ [92Nb]
459
+ [97Ru]
460
+ [2H+1]
461
+ [Cr+6]
462
+ [#14N]
463
+ [122Sn]
464
+ [=Pr]
465
+ [146Ce]
466
+ [SnH2]
467
+ [174Hf]
468
+ [212Pb+2]
469
+ [164Ho]
470
+ [TaH2]
471
+ [=Mo]
472
+ [104Cd]
473
+ [140Ce]
474
+ [98Mo]
475
+ [126Ba]
476
+ [Sn+3]
477
+ [=YH1]
478
+ [137Ce]
479
+ [85Kr]
480
+ [222Fr]
481
+ [CeH3]
482
+ [111Cd+2]
483
+ [Pd+1]
484
+ [24Mg]
485
+ [241Pu]
486
+ [/80Br]
487
+ [19O]
488
+ [129Cs+1]
489
+ [=PH1]
490
+ [127I-1]
491
+ .
492
+ [=14C]
493
+ [65Ga]
494
+ [12C@]
495
+ [GeH1]
496
+ [Ga-3]
497
+ [Ge-2]
498
+ [3HH1]
499
+ [/Br-1]
500
+ [33SH2]
501
+ [16OH1-1]
502
+ [133Xe]
503
+ [\123I]
504
+ [#MoH1]
505
+ [244Am]
506
+ [LaH3]
507
+ [\SnH3]
508
+ [/Al+2]
509
+ [157Gd]
510
+ [132Ba]
511
+ [Tl-1]
512
+ [10BH1-1]
513
+ [212Pb]
514
+ [Si+1]
515
+ [161Gd]
516
+ [=BH2-1]
517
+ [52Cr]
518
+ [30PH3]
519
+ [\CH1-1]
520
+ [238Pu]
521
+ [#Ta+1]
522
+ [69Ga+3]
523
+ [144Nd]
524
+ [=Be]
525
+ [97Nb]
526
+ [#N]
527
+ [206Tl]
528
+ [UH3]
529
+ [=P-1]
530
+ [141Nd]
531
+ [83Sr+2]
532
+ [109Cd+2]
533
+ [185W]
534
+ [46Sc]
535
+ [Ir-3]
536
+ [32S]
537
+ [75Se]
538
+ [/PH1-1]
539
+ [250Cm]
540
+ [BiH4-1]
541
+ [\PH3+1]
542
+ [166Tm]
543
+ [203Hg+1]
544
+ [Mg]
545
+ [Gd+2]
546
+ [11C-1]
547
+ [91Y+3]
548
+ [Tb+3]
549
+ [\C+1]
550
+ [FeH6-4]
551
+ [12C]
552
+ [141Sm]
553
+ [S]
554
+ [ReH7]
555
+ [P@H1]
556
+ [/SnH2]
557
+ [13OH1]
558
+ [IH1+1]
559
+ [Fe+3]
560
+ [Ge@@H1]
561
+ [=12CH1]
562
+ [S@@]
563
+ [Mo-2]
564
+ [182W]
565
+ [=13O]
566
+ [190Po]
567
+ [131La]
568
+ [13CH1+1]
569
+ [157Gd+3]
570
+ [BiH1+1]
571
+ [109In]
572
+ [OsH3]
573
+ [#Si+1]
574
+ [137Ba]
575
+ [211Po]
576
+ [130I-1]
577
+ [/123I]
578
+ [Kr]
579
+ [228Rn]
580
+ [25Mg]
581
+ [13CH1]
582
+ [Sc]
583
+ [Rn]
584
+ [\I]
585
+ [228Ac]
586
+ [22Na+1]
587
+ [Cu]
588
+ [=Tc+2]
589
+ [Ti-1]
590
+ [55Fe+2]
591
+ [=Se]
592
+ [Ni+1]
593
+ [Po]
594
+ [149Eu]
595
+ [ThH2]
596
+ [=S]
597
+ [CoH1+2]
598
+ [#Cl]
599
+ [#SiH1]
600
+ [13CH3+1]
601
+ [224Ac]
602
+ [60Co+3]
603
+ [\As]
604
+ [9Be]
605
+ [BH1]
606
+ [245Pu]
607
+ [#PH2]
608
+ [249Cm]
609
+ [138La]
610
+ [#Branch2]
611
+ [SiH3+1]
612
+ [231Th]
613
+ [-\Ring1]
614
+ [122IH1]
615
+ [117Sn+4]
616
+ [180Os]
617
+ [126Sb]
618
+ [209Tl]
619
+ [\Si]
620
+ [\Sn]
621
+ [67Ga+3]
622
+ [=Ca]
623
+ [208Pb]
624
+ [137Ba+2]
625
+ [99Tc]
626
+ [Ru+8]
627
+ [\11C]
628
+ [=FeH1]
629
+ [BH2+1]
630
+ [IH2+1]
631
+ [243Pu]
632
+ [32PH2]
633
+ [MoH2]
634
+ [TiH2]
635
+ [/Al+1]
636
+ [237Pu]
637
+ [\76Br]
638
+ [H]
639
+ [B-2]
640
+ [WH2]
641
+ [Nb]
642
+ [GaH2]
643
+ [\Pb]
644
+ [60Ni]
645
+ [238Cm]
646
+ [\C@@H1]
647
+ [218AtH1]
648
+ [P@H1+1]
649
+ [Co-1]
650
+ [\Sn+1]
651
+ [159Ho]
652
+ [BH2]
653
+ [11B-1]
654
+ [Ta-2]
655
+ [70Ge]
656
+ [/34S]
657
+ [134IH1]
658
+ [Rb+1]
659
+ [153Gd]
660
+ [135La]
661
+ [=Al-1]
662
+ [YbH2]
663
+ [/127I]
664
+ [Ho+3]
665
+ [44Sc+3]
666
+ [48V]
667
+ [104Ag]
668
+ [ClH2+2]
669
+ [12B]
670
+ [ReH3]
671
+ [43K+1]
672
+ [=NH0]
673
+ [\N-1]
674
+ [22CH3-1]
675
+ [Bi+2]
676
+ [82Kr]
677
+ [102Rh]
678
+ [#Sc]
679
+ [192Po]
680
+ [228Th+4]
681
+ [225Ra]
682
+ [/Sn+3]
683
+ [31PH3]
684
+ [#Ga]
685
+ [101Mo]
686
+ [232U]
687
+ [BiH1]
688
+ [220Fr]
689
+ [#17O+1]
690
+ [128Sn]
691
+ [18FH1]
692
+ [SiH2-2]
693
+ [=16N]
694
+ [75As]
695
+ [99Tc+4]
696
+ [210Pb]
697
+ [BrH1]
698
+ [\Bi]
699
+ [SnH3+1]
700
+ [\CH2+1]
701
+ [Al-3]
702
+ [254Es]
703
+ [66Zn+2]
704
+ [S@@H1]
705
+ [Ni-3]
706
+ [94Nb]
707
+ [217Bi]
708
+ [11C]
709
+ [166Tb]
710
+ [CH3]
711
+ [175Hf]
712
+ [AlH1+1]
713
+ [SbH2+1]
714
+ [162Ho]
715
+ [90Mo]
716
+ [Os+4]
717
+ [=Si-1]
718
+ [204Tl]
719
+ [13CH1-1]
720
+ [U+3]
721
+ [\P@]
722
+ [Cl+1]
723
+ [155Eu]
724
+ [215Po]
725
+ [33PH1]
726
+ [Cd]
727
+ [AtH1]
728
+ [57Fe]
729
+ [/CH2-1]
730
+ [142La]
731
+ [Se-1]
732
+ [14CH2]
733
+ [Cu-4]
734
+ [Sr+2]
735
+ [/C]
736
+ [35Cl]
737
+ [191Pt+2]
738
+ [169Er]
739
+ [15NH4+1]
740
+ [23Na]
741
+ [38Ar]
742
+ [/Sn+2]
743
+ [143La]
744
+ [43Ca]
745
+ [\I+1]
746
+ [213BiH1]
747
+ [SH2+1]
748
+ [13C@@]
749
+ [14CH3]
750
+ [194Hg]
751
+ [70Se]
752
+ [Zr+3]
753
+ [18O]
754
+ [=Ru]
755
+ [EuH2]
756
+ [#13C]
757
+ [SiH3]
758
+ [=13C]
759
+ [\14C@H1]
760
+ [-\Ring2]
761
+ [14C]
762
+ [/15N]
763
+ [\-Ring3]
764
+ [14CH4]
765
+ [46Ca]
766
+ [10B]
767
+ [#B]
768
+ [66Zn]
769
+ [#Sb]
770
+ [Os+1]
771
+ [=99Tc+2]
772
+ [#17C-1]
773
+ [Au]
774
+ [75SeH1]
775
+ [179Ta]
776
+ [139Pr]
777
+ [89Y]
778
+ [Branch2]
779
+ [/O-1]
780
+ [200Bi]
781
+ [2HH1]
782
+ [=13CH1]
783
+ [Fr]
784
+ [166Yb]
785
+ [239Pu]
786
+ [11CH3-1]
787
+ [103Ru]
788
+ [61Co]
789
+ [106Pd]
790
+ [103Rh]
791
+ [35SH1]
792
+ [Sb]
793
+ [18OH3+1]
794
+ [47V]
795
+ [50Cr+3]
796
+ [121Sn]
797
+ [171Lu]
798
+ [184Hf]
799
+ [110In]
800
+ [247Bk]
801
+ [AsH2]
802
+ [184Os]
803
+ [Er+3]
804
+ [86Zr]
805
+ [#Ni]
806
+ [126I]
807
+ [14NH3]
808
+ [32PH3]
809
+ [Si-1]
810
+ [125Te]
811
+ [#Ru]
812
+ [Ru-2]
813
+ [76Br-1]
814
+ [227Ra]
815
+ [/OH0]
816
+ [=14CH2]
817
+ [NH0]
818
+ [227Ac]
819
+ [234Pa]
820
+ [OsH1-1]
821
+ [69Ga]
822
+ [182Re]
823
+ [U+4]
824
+ [239Np]
825
+ [WH3]
826
+ [Ru+2]
827
+ [/N@+1]
828
+ [=In]
829
+ [201Bi]
830
+ [126Sb+3]
831
+ [Pd-1]
832
+ [#188Re]
833
+ [=C]
834
+ [OsH1]
835
+ [45Sc]
836
+ [/S-1]
837
+ [=99Tc+1]
838
+ [=VH1]
839
+ [GeH2-1]
840
+ [/NH2+1]
841
+ [NbH3]
842
+ [Sn-1]
843
+ [230U]
844
+ [37SH2]
845
+ [180W]
846
+ [105Ag]
847
+ [67Ge]
848
+ [91Zr]
849
+ [Tb+4]
850
+ [\14CH1]
851
+ [=WH1]
852
+ [UH2]
853
+ [258Md]
854
+ [Dy+3]
855
+ [220Rn]
856
+ [TeH3]
857
+ [86Sr]
858
+ [#Branch1]
859
+ [=15NH2+1]
860
+ [#Br]
861
+ [42Ca]
862
+ [46Ti]
863
+ [IrH1]
864
+ [133I-1]
865
+ [3H]
866
+ [/Se]
867
+ [/Ga]
868
+ [11CH4]
869
+ [Bi+1]
870
+ [MnH1]
871
+ [#18CH1]
872
+ [Zn-4]
873
+ [156Sm]
874
+ [113Ag]
875
+ [\BiH1]
876
+ [128Xe]
877
+ [175Ta]
878
+ [\NH3+1]
879
+ [=SeH1]
880
+ [69Zn]
881
+ [\Al]
882
+ [#W+1]
883
+ [233Np]
884
+ [253Cf]
885
+ [134Cs]
886
+ [\Br]
887
+ [253Es]
888
+ [C@@H1]
889
+ [#13N]
890
+ [/P@]
891
+ [173Ta]
892
+ [Nb+2]
893
+ [VH1]
894
+ [126I-1]
895
+ [121I-1]
896
+ [207At]
897
+ [\S]
898
+ [182Os]
899
+ [7Li]
900
+ [SH1]
901
+ [/AlH1+1]
902
+ [115In]
903
+ [AlH4-1]
904
+ [59Ni]
905
+ [123IH1]
906
+ [FH1+1]
907
+ [82Br-1]
908
+ [Cl@@-1]
909
+ [137Pr]
910
+ [SbH5]
911
+ [67Zn+2]
912
+ [132I-1]
913
+ [\SiH3]
914
+ [AlH3]
915
+ [AsH3]
916
+ [111In-1]
917
+ [/76Br]
918
+ [164Dy+3]
919
+ [50Cr]
920
+ [=Tc+5]
921
+ [82Se+6]
922
+ [SeH3+1]
923
+ [#W-1]
924
+ [Ir-2]
925
+ [\13C@@H1]
926
+ [/AlH2]
927
+ [99Mo]
928
+ [/14C@H1]
929
+ [76Br]
930
+ [Ag]
931
+ [145Eu]
932
+ [135I]
933
+ [/PH1]
934
+ [141Ce+3]
935
+ [84Sr]
936
+ [B+2]
937
+ [Th+2]
938
+ [117SnH2]
939
+ [=64Zn]
940
+ [Mg+1]
941
+ [38Cl-1]
942
+ [140Ba]
943
+ [22Ne]
944
+ [118Sn]
945
+ [145Pr]
946
+ [202Pb]
947
+ [125Sn+4]
948
+ [61Ni]
949
+ [233U+4]
950
+ [/18F]
951
+ [SeH1-1]
952
+ [12CH4]
953
+ [Cu-5]
954
+ [/NH0]
955
+ [=SH1+1]
956
+ [#U]
957
+ [153Sm]
958
+ [76Ge]
959
+ [207Tl]
960
+ [BiH5]
961
+ [Ru+4]
962
+ [ZrH1]
963
+ [131I]
964
+ [81Kr]
965
+ [66Ge]
966
+ [9C]
967
+ [193Os]
968
+ [59Co]
969
+ [Pb]
970
+ [Cr-1]
971
+ [95Zr]
972
+ [Gd+3]
973
+ [#PbH1]
974
+ [18OH1]
975
+ [134La]
976
+ [15CH2]
977
+ [Al+2]
978
+ [214Pb]
979
+ [17NH3]
980
+ [134Ba]
981
+ [\Si+1]
982
+ [17B]
983
+ [145Pm]
984
+ [/12C]
985
+ [Tl+1]
986
+ [=Fe]
987
+ [170Lu]
988
+ [182Ta]
989
+ [95Nb]
990
+ [SnH4+2]
991
+ [=As+1]
992
+ [\CH0]
993
+ [#S]
994
+ [79Rb]
995
+ [47Sc]
996
+ [49V]
997
+ [Nb-2]
998
+ [=As]
999
+ [81Se]
1000
+ [19FH1]
1001
+ [75Ge]
1002
+ [99Y]
1003
+ [79Br]
1004
+ [193Au]
1005
+ [210BiH3]
1006
+ [73Se]
1007
+ [54Mn]
1008
+ [51Ti]
1009
+ [ClH2+1]
1010
+ [90Sr+2]
1011
+ [TiH1+1]
1012
+ [129IH1]
1013
+ [/15N+1]
1014
+ [Fe+2]
1015
+ [199Hg]
1016
+ [74Br-1]
1017
+ [\15NH2]
1018
+ [85Rb+1]
1019
+ [42K+1]
1020
+ [203Tl+1]
1021
+ [#Er]
1022
+ [=76As]
1023
+ [SnH4]
1024
+ [/C@@]
1025
+ [182Ir]
1026
+ [VH2]
1027
+ [150Nd]
1028
+ [PH2+1]
1029
+ [137La]
1030
+ [135Xe]
1031
+ [179Hf]
1032
+ [HgH1]
1033
+ [Nd+3]
1034
+ [#O+1]
1035
+ [ReH4]
1036
+ [\Al-1]
1037
+ [Bi]
1038
+ [133Ba+2]
1039
+ [138Cs+1]
1040
+ [231Pa]
1041
+ [90Zr]
1042
+ [\CH1+1]
1043
+ [105Rh]
1044
+ [166Er]
1045
+ [34Cl-1]
1046
+ [PtH2+2]
1047
+ [/CH1-1]
1048
+ [=12CH2]
1049
+ [U]
1050
+ [Zn-1]
1051
+ [/IH1]
1052
+ [=13C-1]
1053
+ [=18O+1]
1054
+ [S@@+1]
1055
+ [154Eu+3]
1056
+ [97Zr]
1057
+ [178Yb]
1058
+ [InH1]
1059
+ [24Na]
1060
+ [82Br]
1061
+ [137Xe]
1062
+ [132La]
1063
+ [218Rn]
1064
+ [37S]
1065
+ [53Mn]
1066
+ [\W]
1067
+ [CeH1]
1068
+ [RuH5]
1069
+ [/PH2+1]
1070
+ [Re-2]
1071
+ [/Po]
1072
+ [28Si]
1073
+ [135Cs+1]
1074
+ [68Ga+3]
1075
+ [Co-4]
1076
+ [Sb+5]
1077
+ [177Yb]
1078
+ [=Ti]
1079
+ [246Cf]
1080
+ [196Bi]
1081
+ [22CH3]
1082
+ [90Nb]
1083
+ [#V+1]
1084
+ [GeH2+1]
1085
+ [243Am]
1086
+ [\B]
1087
+ [#Ir+1]
1088
+ [127Xe]
1089
+ [191Ir]
1090
+ [KrH1]
1091
+ [No]
1092
+ [#La]
1093
+ [194Ir]
1094
+ [89Sr+2]
1095
+ [/13CH1]
1096
+ [185Re]
1097
+ [\Cl]
1098
+ [/N+1]
1099
+ [\S@]
1100
+ [Tc+5]
1101
+ [60Cu]
1102
+ [/C@]
1103
+ [BiH2+1]
1104
+ [193Hg]
1105
+ [102Pd]
1106
+ [=188Re]
1107
+ [AsH3+1]
1108
+ [203Bi]
1109
+ [Pr]
1110
+ [/Cl+1]
1111
+ [94Zr]
1112
+ [43K]
1113
+ [138Cs]
1114
+ [153Gd+3]
1115
+ [\-Ring2]
1116
+ [OsH6]
1117
+ [=Er]
1118
+ [MnH1+1]
1119
+ [159Gd+3]
1120
+ [12NH3]
1121
+ [67Cu]
1122
+ [/XeH1]
1123
+ [77Br-1]
1124
+ [=14N]
1125
+ [=C-1]
1126
+ [MgH1]
1127
+ [#13C-1]
1128
+ [Hg+1]
1129
+ [SeH2]
1130
+ [=99Tc+4]
1131
+ [28Al]
1132
+ [Cm]
1133
+ [82Rb+1]
1134
+ [252Cf]
1135
+ [159Dy]
1136
+ [52Fe+3]
1137
+ [Se@]
1138
+ [BH0]
1139
+ [81Rb]
1140
+ [106Rh]
1141
+ [74BrH1]
1142
+ [210Bi]
1143
+ [206Bi]
1144
+ [\C@]
1145
+ [73As]
1146
+ [Cu-1]
1147
+ [\SiH2+1]
1148
+ [\Po]
1149
+ [Te+1]
1150
+ [144Ce+3]
1151
+ [41Ca+2]
1152
+ [132Xe]
1153
+ [=Xe]
1154
+ [87Y]
1155
+ [187Ir]
1156
+ [Br-1]
1157
+ [17O-1]
1158
+ [Cl+2]
1159
+ [229Th]
1160
+ [#Re]
1161
+ [146Eu]
1162
+ [238Am]
1163
+ [79Se]
1164
+ [136Ce]
1165
+ [SbH3+1]
1166
+ [58Co+2]
1167
+ [AsH2-1]
1168
+ [#C]
1169
+ [150Tb]
1170
+ [/18O]
1171
+ [109Cd]
1172
+ [B@@H1-1]
1173
+ [=11CH2]
1174
+ [124Xe]
1175
+ [1H]
1176
+ [#Nb]
1177
+ [219Rn]
1178
+ [Al]
1179
+ [90Y]
1180
+ [Cu-2]
1181
+ [170Er]
1182
+ [15OH2]
1183
+ [149Pm]
1184
+ [=O]
1185
+ [Rh]
1186
+ [228Th]
1187
+ [SbH6+3]
1188
+ [250Cf]
1189
+ [197Pb]
1190
+ [/CH2+1]
1191
+ [Pd+2]
1192
+ [12C@@]
1193
+ [10B-1]
1194
+ [#Pd]
1195
+ [=18C]
1196
+ [Ce+4]
1197
+ [\CH2-1]
1198
+ [13CH2-1]
1199
+ [181Ta+2]
1200
+ [\14C@]
1201
+ [117Cd]
1202
+ [186Ta]
1203
+ [#15N]
1204
+ [\SeH1]
1205
+ [\Se]
1206
+ [/SiH1]
1207
+ [HgH2]
1208
+ [32P+1]
1209
+ [V-1]
1210
+ [Cr+1]
1211
+ [SiH1-2]
1212
+ [=13N]
1213
+ [1H-1]
1214
+ [/35S]
1215
+ [13C-1]
1216
+ [74Se]
1217
+ [64Zn]
1218
+ [Cl]
1219
+ [142Pr]
1220
+ [72Br-1]
1221
+ [Pd]
1222
+ [200Tl]
1223
+ [92Sr+2]
1224
+ [=B-1]
1225
+ [79BrH1]
1226
+ [122I-1]
1227
+ [86Rb+1]
1228
+ [C-1]
1229
+ [187Re]
1230
+ [202Hg]
1231
+ [213Bi+3]
1232
+ [PtH3]
1233
+ [=35S]
1234
+ [39Ar]
1235
+ [13C+1]
1236
+ [152Sm+3]
1237
+ [161Ho]
1238
+ [181Hf]
1239
+ [26Mg]
1240
+ [/32P]
1241
+ [#C-1]
1242
+ [203Hg]
1243
+ [131Ba]
1244
+ [AsH4+1]
1245
+ [=SiH1]
1246
+ [FeH2]
1247
+ [227Th]
1248
+ [89Rb+1]
1249
+ [\14CH3]
1250
+ [152Tb]
1251
+ [Zr-4]
1252
+ [124IH1]
1253
+ [154Tb]
1254
+ [12CH3]
1255
+ [62Cu]
1256
+ [133I]
1257
+ [SiH2+1]
1258
+ [#SeH1]
1259
+ [39K+1]
1260
+ [As+3]
1261
+ [82BrH1]
1262
+ [/SiH3]
1263
+ [195Pb]
1264
+ [PdH1]
1265
+ [FeH3]
1266
+ [Pt-2]
1267
+ [=Mo+2]
1268
+ [/14CH1]
1269
+ [GaH4-1]
1270
+ [Ni-4]
1271
+ [Rh-2]
1272
+ [\Hg+1]
1273
+ [146Sm]
1274
+ [173Tm]
1275
+ [Pt+2]
1276
+ [P-3]
1277
+ [/I+1]
1278
+ [199Au]
1279
+ [66Ni]
1280
+ [78BrH1]
1281
+ [211Rn]
1282
+ [157Sm]
1283
+ [=Ni]
1284
+ [BrH2+1]
1285
+ [=S+1]
1286
+ [136Cs]
1287
+ [130Xe]
1288
+ [144Pr+3]
1289
+ [210At]
1290
+ [Cr+4]
1291
+ [128IH1]
1292
+ [174Lu]
1293
+ [185Ta]
1294
+ [=Y]
1295
+ [148Eu]
1296
+ [13N]
1297
+ [55Fe]
1298
+ [149Nd]
1299
+ [120IH1]
1300
+ [205Pb]
1301
+ [=125Te]
1302
+ [=GeH1]
1303
+ [=Ce]
1304
+ [90Zr+4]
1305
+ [105Pd]
1306
+ [32ClH1]
1307
+ [Mo-3]
1308
+ [/TlH1]
1309
+ [242Pu]
1310
+ [84Rb]
1311
+ [51Mn]
1312
+ [97Tc]
1313
+ [11CH3+1]
1314
+ [PbH1]
1315
+ [40K+1]
1316
+ [254Cf]
1317
+ [130IH1]
1318
+ [88Nb]
1319
+ [Ti]
1320
+ [90Y+3]
1321
+ [132Cs]
1322
+ [129Te]
1323
+ [/I-1]
1324
+ [182Hf]
1325
+ [CoH2]
1326
+ [TeH2]
1327
+ [#15O+1]
1328
+ [B]
1329
+ [131Cs+1]
1330
+ [59Co+3]
1331
+ [RhH1]
1332
+ [NiH1+1]
1333
+ [Zr-1]
1334
+ [Os-3]
1335
+ [204Hg+1]
1336
+ [193Pt+2]
1337
+ [I-1]
1338
+ [35S-1]
1339
+ [=15N]
1340
+ [\SnH1]
1341
+ [H-1]
1342
+ [108Cd]
1343
+ [11CH1]
1344
+ [176Yb]
1345
+ [TiH1]
1346
+ [48Ca]
1347
+ [=PH1+1]
1348
+ [195Ir]
1349
+ [La+3]
1350
+ [Se]
1351
+ [153Eu]
1352
+ [Hg+2]
1353
+ [138Pr]
1354
+ [Sb+1]
1355
+ [101Tc]
1356
+ [112Sn]
1357
+ [/InH2]
1358
+ [Tm+3]
1359
+ [#Zr]
1360
+ [PbH2+2]
1361
+ [\N@@+1]
1362
+ [114Cd]
1363
+ [Nb+5]
1364
+ [194Au]
1365
+ [BH4+1]
1366
+ [/GeH3]
1367
+ [66Ga]
1368
+ [\C-1]
1369
+ [96Zr]
1370
+ [204Po]
1371
+ [SiH2-1]
1372
+ [63Ni]
1373
+ [167Er]
1374
+ [234U]
1375
+ [Os+6]
1376
+ [201Po]
1377
+ [130Te]
1378
+ [/ClH1+1]
1379
+ [129I-1]
1380
+ [/Al]
1381
+ [Cr+5]
1382
+ [173Hf]
1383
+ [14C@@H1]
1384
+ [YH1]
1385
+ [57Mn]
1386
+ [111Cd]
1387
+ [102Ru]
1388
+ [/Sn]
1389
+ [21Ne]
1390
+ [160Dy]
1391
+ [139La]
1392
+ [89Sr]
1393
+ [257Fm]
1394
+ [Zn-3]
1395
+ [40PH1]
1396
+ [#Pb]
1397
+ [136Xe]
1398
+ [213Pb]
1399
+ [101Pd]
1400
+ [\BH0]
1401
+ [=17O]
1402
+ [1H+1]
1403
+ [87Kr]
1404
+ [158Gd]
1405
+ [NiH2]
1406
+ [\P@@]
1407
+ [PH1+1]
1408
+ [Al-1]
1409
+ [Cr]
1410
+ [99Tc+7]
1411
+ [#Fe+1]
1412
+ [172Yb]
1413
+ [=Ti+2]
1414
+ [235Pu]
1415
+ [\Se-1]
1416
+ [198Po]
1417
+ [134Te]
1418
+ [18CH2]
1419
+ [171Er]
1420
+ [69As]
1421
+ [/CH1+1]
1422
+ [Ho]
1423
+ [IrH2]
1424
+ [40PH3]
1425
+ [AsH5]
1426
+ [\Te+1]
1427
+ [Tc+4]
1428
+ [Te@]
1429
+ [Lr]
1430
+ [75As+3]
1431
+ [119Sn]
1432
+ [203Pb+2]
1433
+ [68Ge]
1434
+ [197Tl]
1435
+ [BH1+1]
1436
+ [15CH4]
1437
+ [209Bi]
1438
+ [75Br-1]
1439
+ [44Ca+2]
1440
+ [TeH3+1]
1441
+ [17C]
1442
+ [/14CH2-1]
1443
+ [=BiH1]
1444
+ [112In]
1445
+ [=Tc+1]
1446
+ [=15N-1]
1447
+ [61Cu+1]
1448
+ [4He]
1449
+ [51Cr]
1450
+ [Au+3]
1451
+ [=Tm]
1452
+ [222Rn]
1453
+ [72Ga]
1454
+ [P@+1]
1455
+ [193Pt+4]
1456
+ [Rf]
1457
+ [=P]
1458
+ [178Lu]
1459
+ [172Er]
1460
+ [110Pd]
1461
+ [200Pt]
1462
+ [SnH1+2]
1463
+ [83Se]
1464
+ [196Po]
1465
+ [111InH3]
1466
+ [=Nd]
1467
+ [\125I]
1468
+ [Br]
1469
+ [P@@]
1470
+ [70As]
1471
+ [SbH4]
1472
+ [Fe]
1473
+ [144Pr]
1474
+ [151Eu+3]
1475
+ [45Ca+2]
1476
+ [11CH2]
1477
+ [66Ga+3]
1478
+ [Cd+2]
1479
+ [64Zn+2]
1480
+ [152Dy]
1481
+ [15O-2]
1482
+ [AlH1+2]
1483
+ [106Ag]
1484
+ [=OH1+1]
1485
+ [120I]
1486
+ [OH3+1]
1487
+ [106Cd]
1488
+ [=15N+1]
1489
+ [52V]
1490
+ [116Cd]
1491
+ [177W]
1492
+ [#Pr]
1493
+ [As+1]
1494
+ [GaH1-1]
1495
+ [230Pu]
1496
+ [=Sb+1]
1497
+ [IrH3]
1498
+ [218At]
1499
+ [234Np]
1500
+ [155Ho]
1501
+ [118Pd+2]
1502
+ [192Os]
1503
+ [/13CH2]
1504
+ [#14CH1]
1505
+ [/Te+1]
1506
+ [134Xe]
1507
+ [10BH2]
1508
+ [169Yb]
1509
+ [/37Cl]
1510
+ [76As]
1511
+ [=Ba]
1512
+ [=Re]
1513
+ [/C@H1]
1514
+ [SnH1-1]
1515
+ [\HgH1]
1516
+ [223Ac]
1517
+ [SnH3-1]
1518
+ [143Pr]
1519
+ [\IH1+1]
1520
+ [=BrH1]
1521
+ [103Cd]
1522
+ [Si@]
1523
+ [FeH6]
1524
+ [\PH1+1]
1525
+ [Pt-1]
1526
+ [#Tc+1]
1527
+ [96Nb]
1528
+ [103Pd]
1529
+ [Br+1]
1530
+ [19C]
1531
+ [=Os+2]
1532
+ [83BrH1]
1533
+ [#Tl]
1534
+ [#18C-1]
1535
+ [244Pu]
1536
+ [136Eu]
1537
+ [Mn+1]
1538
+ [54Cr]
1539
+ [\O+1]
1540
+ [S@+1]
1541
+ [201Tl]
1542
+ [\C@@]
1543
+ [SH3+1]
1544
+ [/125I]
1545
+ [144Pm]
1546
+ [123Sn]
1547
+ [Na]
1548
+ [161Tb+3]
1549
+ [68Zn]
1550
+ [=70Zn]
1551
+ [Nd]
1552
+ [/13C@@H1]
1553
+ [86Y]
1554
+ [Fe+6]
1555
+ [Al-2]
1556
+ [121Xe]
1557
+ [Mo+4]
1558
+ [Es]
1559
+ [19B]
1560
+ [115Sb]
1561
+ [38SH2]
1562
+ [14CH2-1]
1563
+ [=SiH2]
1564
+ [=Si+1]
1565
+ [201Au]
1566
+ [11CH1-1]
1567
+ [28SiH3]
1568
+ [Mo]
1569
+ [109Pd+2]
1570
+ [YH2]
1571
+ [#17CH1]
1572
+ [Au+1]
1573
+ [127Te]
1574
+ [#W]
1575
+ [S+1]
1576
+ [173Lu]
1577
+ [Xe]
1578
+ [104Pd]
1579
+ [/N]
1580
+ [SH0]
1581
+ [14O]
1582
+ [Ca-2]
1583
+ [=XeH1]
1584
+ [InH4-1]
1585
+ [Si-2]
1586
+ [AsH4]
1587
+ [99Ru+2]
1588
+ [Zn]
1589
+ [\S-1]
1590
+ [=Te]
1591
+ [Br+2]
1592
+ [198Tl]
1593
+ [25Mg+2]
1594
+ [/N-1]
1595
+ [10BH3]
1596
+ [195Pt+4]
1597
+ [236Pu]
1598
+ [I+1]
1599
+ [/SiH1-1]
1600
+ [InH2]
1601
+ [\B@-1]
1602
+ [60Fe]
1603
+ [14OH2]
1604
+ [233Pa]
1605
+ [199Tl+1]
1606
+ [Am]
1607
+ [Eu]
1608
+ [=GeH2]
1609
+ [158Tb]
1610
+ [=Hf]
1611
+ [=WH2]
1612
+ [AlH2+1]
1613
+ [Er]
1614
+ [189Pt]
1615
+ [172Tm]
1616
+ [Pt-4]
1617
+ [16CH2]
1618
+ [16N+1]
1619
+ [BH1-1]
1620
+ [148Pm]
1621
+ [225Ac]
1622
+ [=19C]
1623
+ [99Rh]
1624
+ [125I]
1625
+ [\79Br]
1626
+ [ReH1]
1627
+ [27Al+3]
1628
+ [Ir]
1629
+ [\AsH2]
1630
+ [23Na+1]
1631
+ [Md]
1632
+ [119In]
1633
+ [56Co]
1634
+ [104Rh]
1635
+ [\C@H1]
1636
+ [235U]
1637
+ [MoH3]
1638
+ [\In]
1639
+ [247Cm]
1640
+ [\O-1]
1641
+ [/P@@]
1642
+ [36Cl]
1643
+ [153Sm+3]
1644
+ [236Np]
1645
+ [164Dy]
1646
+ [U+2]
1647
+ [/Sn+1]
1648
+ [16C]
1649
+ [KH1]
1650
+ [Zr-3]
1651
+ [241Am]
1652
+ [131IH1]
1653
+ [ClH1+2]
1654
+ [121SnH2]
1655
+ [MoH5]
1656
+ [/AsH1]
1657
+ [#18O+1]
1658
+ [Re+1]
1659
+ [187Os]
1660
+ [=SiH1-1]
1661
+ [170Hf]
1662
+ [37Cl]
1663
+ [184Ir]
1664
+ [\TeH1]
1665
+ [\Sn-1]
1666
+ [/11CH3]
1667
+ [#Tm]
1668
+ [189Os]
1669
+ [48Cr]
1670
+ [120Te]
1671
+ [201Hg]
1672
+ [PH1-1]
1673
+ [=AsH2]
1674
+ [I+2]
1675
+ [\ClH1+1]
1676
+ [62Cu+1]
1677
+ [Si@@]
1678
+ [\I-1]
1679
+ [=PH0]
1680
+ [BrH0]
1681
+ [Li]
1682
+ [O+1]
1683
+ [117Sn]
1684
+ [199Tl]
1685
+ [148Nd]
1686
+ [NaH1]
1687
+ [62Zn+2]
1688
+ [S-2]
1689
+ [3He]
1690
+ [Ta+5]
1691
+ [In-1]
1692
+ [82Sr+2]
1693
+ [194Tl]
1694
+ [C]
1695
+ [GeH4]
1696
+ [36ClH1]
1697
+ [14N]
1698
+ [73Ga]
1699
+ [=99Tc+5]
1700
+ [TeH2+1]
1701
+ [SbH2]
1702
+ [210Tl]
1703
+ [13C]
1704
+ [=Tl]
1705
+ [\15N]
1706
+ [/SeH1]
1707
+ [181W]
1708
+ [9Li]
1709
+ [82Rb]
1710
+ [72Zn]
1711
+ [124Te]
1712
+ [Ac]
1713
+ [/P]
1714
+ [156Eu]
1715
+ [203PbH1]
1716
+ [110Ag]
1717
+ [144Sm]
1718
+ [Li+1]
1719
+ [Ni]
1720
+ [71Ga]
1721
+ [65Cu+2]
1722
+ [63Ni+2]
1723
+ [CuH2-1]
1724
+ [113Cd]
1725
+ [Cl@-1]
1726
+ [178Hf]
1727
+ [=S@]
1728
+ [45K]
1729
+ [127Cs+1]
1730
+ [RuH1-1]
1731
+ [171Yb]
1732
+ [TiH4]
1733
+ [58Fe+3]
1734
+ [231U]
1735
+ [Cr-2]
1736
+ [ClH1-1]
1737
+ [OH0]
1738
+ [37Ar]
1739
+ [94Y]
1740
+ [EuH3]
1741
+ [P@@H1+1]
1742
+ [P-1]
1743
+ [Co+1]
1744
+ [131Te]
1745
+ [18F-1]
1746
+ [=Mn]
1747
+ [67Cu+2]
1748
+ [200Po]
1749
+ [=14CH1]
1750
+ [Os+5]
1751
+ [86Rb]
1752
+ [SeH5]
1753
+ [Lu+3]
1754
+ [106Ru+3]
1755
+ [/C@@H1]
1756
+ [/124I]
1757
+ [=Ru+1]
1758
+ [91Sr]
1759
+ [#14C-1]
1760
+ [/GeH2]
1761
+ [15NH1-1]
1762
+ [201Pb]
1763
+ [240Pu]
1764
+ [192Bi]
1765
+ [Si@@H1]
1766
+ [38K+1]
1767
+ [As+5]
1768
+ [Cd-2]
1769
+ [197Hg]
1770
+ [=Sb]
1771
+ [CH1+1]
1772
+ [18O-1]
1773
+ [Np]
1774
+ [Ru-4]
1775
+ [F]
1776
+ [=Tc]
1777
+ [CH2-1]
1778
+ [Ir+1]
1779
+ [109Pd]
1780
+ [SnH2-1]
1781
+ [\P-1]
1782
+ [17OH1]
1783
+ [142Pm]
1784
+ [Ca-4]
1785
+ [116Te]
1786
+ [Hf]
1787
+ [7Li+1]
1788
+ [18F]
1789
+ [Cr-3]
1790
+ [/Si+1]
1791
+ [ScH3]
1792
+ [51Fe]
1793
+ [155Dy]
1794
+ [191Pt+4]
1795
+ [178Ta]
1796
+ [126Sn]
1797
+ [148Gd]
1798
+ [NH1+1]
1799
+ [94Ru]
1800
+ [123I-1]
1801
+ [38S]
1802
+ [64Ni]
1803
+ [/14CH3]
1804
+ [=Sr]
1805
+ [192Ir]
1806
+ [=Th+2]
1807
+ [Ni+3]
1808
+ [PH1-2]
1809
+ [85Br]
1810
+ [=Zn]
1811
+ [=B]
1812
+ [Au-1]
1813
+ [=RhH1]
1814
+ [211At]
1815
+ [65Zn+2]
1816
+ [OH1-1]
1817
+ [P@@+1]
1818
+ [/SH2+1]
1819
+ [BH2-1]
1820
+ [CaH2]
1821
+ [N+1]
1822
+ [113In]
1823
+ [33P]
1824
+ [InH1-1]
1825
+ [90Tc]
1826
+ [Ti+1]
1827
+ [\ClH1]
1828
+ [Pt-3]
1829
+ [213Bi]
1830
+ [170Tm+3]
1831
+ [=PH2+1]
1832
+ [/TeH1]
1833
+ [76BrH1]
1834
+ [200Pb]
1835
+ [82Se-2]
1836
+ [191Os]
1837
+ [PtH1]
1838
+ [75BrH1]
1839
+ [Db]
1840
+ [/NH1-1]
1841
+ [\PH1-1]
1842
+ [218Pb]
1843
+ [=Co]
1844
+ [/In]
1845
+ [=Yb]
1846
+ [100Tc+4]
1847
+ [NH4+1]
1848
+ [=Si]
1849
+ [Ga]
1850
+ [=Pd]
1851
+ [64Cu+1]
1852
+ [Ce]
1853
+ [86Tc]
1854
+ [Ru-1]
1855
+ [120I-1]
1856
+ [217At]
1857
+ [\GeH1]
1858
+ [234Pu]
1859
+ [TeH1+1]
1860
+ [/14CH2]
1861
+ [180Re]
1862
+ [62Ga]
1863
+ [=S@@]
1864
+ [15O]
1865
+ [59Fe+3]
1866
+ [168Er]
1867
+ [246Bk]
1868
+ [BH3+1]
1869
+ [81Br-1]
1870
+ [53Cr]
1871
+ [122I]
1872
+ [/Cl-1]
1873
+ [=100Tc+1]
1874
+ [#14C]
1875
+ [127IH1]
1876
+ [PtH1+1]
1877
+ [126IH1]
1878
+ [/-Ring1]
1879
+ [/GeH1]
1880
+ [TeH4]
1881
+ [16NH1]
1882
+ [108Pd]
1883
+ [35S-2]
1884
+ [127I]
1885
+ [161Er]
1886
+ [145Nd]
1887
+ [187W]
1888
+ [\NH1]
1889
+ [Mn-2]
1890
+ [10C]
1891
+ [=Lu]
1892
+ [38K]
1893
+ [Se+1]
1894
+ [28Mg+2]
1895
+ [135IH1]
1896
+ [227Pa]
1897
+ [238Np]
1898
+ [/S@@]
1899
+ [239U]
1900
+ [\Te]
1901
+ [\BH2-1]
1902
+ [#S+1]
1903
+ [XeH2]
1904
+ [154Gd]
1905
+ [Pa]
1906
+ [\N+1]
1907
+ [/BH0]
1908
+ [AlH2-1]
1909
+ [=Ga]
1910
+ [223Fr]
1911
+ [194Os]
1912
+ [161Tb]
1913
+ [#Bi]
1914
+ [K+1]
1915
+ [58Fe+2]
1916
+ [Ra]
1917
+ [OH1]
1918
+ [SiH3-2]
1919
+ [/18C]
1920
+ [AsH2+1]
1921
+ [147Sm]
1922
+ [SnH3]
1923
+ [AsH3-1]
1924
+ [RuH3]
1925
+ [181Os]
1926
+ [63Zn]
1927
+ [81Rb+1]
1928
+ [78As]
1929
+ [162Dy]
1930
+ [=Nb]
1931
+ [=Sn]
1932
+ [177Lu+3]
1933
+ [13NH1]
1934
+ [233Ra]
1935
+ [129I]
1936
+ [118Pd]
1937
+ [131Xe]
1938
+ [=Te-1]
1939
+ [142Ba]
1940
+ [10CH3]
1941
+ [32Si]
1942
+ [234Th]
1943
+ [250Bk]
1944
+ [\14C]
1945
+ [10CH2]
1946
+ [/15NH1]
1947
+ [135I-1]
1948
+ [157Dy]
1949
+ [Ba+2]
1950
+ [/B]
1951
+ [SbH1]
1952
+ [OH2+1]
1953
+ [15CH3]
1954
+ [Ring3]
1955
+ [WH1]
1956
+ [136Pr]
1957
+ [82Sr]
1958
+ [Sn@]
1959
+ [196Pb]
1960
+ [76Kr]
1961
+ [#Mo]
1962
+ [Os-2]
1963
+ [\Ga]
1964
+ [208Tl+1]
1965
+ [138Ce]
1966
+ [#NH1+1]
1967
+ [87Rb]
1968
+ [195Tl]
1969
+ [Zr+4]
1970
+ [8B]
1971
+ [112Ag]
1972
+ [/N@@+1]
1973
+ [150Pm]
1974
+ [106Ru]
1975
+ [13C@@H1]
1976
+ [3H-1]
1977
+ [37ClH1]
1978
+ [227Th+4]
1979
+ [IrH4]
1980
+ [16CH3]
1981
+ [/Bi]
1982
+ [Th+4]
1983
+ [AlH2-2]
1984
+ [/C+1]
1985
+ [/Sb]
1986
+ [242Cm]
1987
+ [39K]
1988
+ [155Gd]
1989
+ [Branch1]
1990
+ [=TaH1]
1991
+ [208Po]
1992
+ [98Nb]
1993
+ [196Au]
1994
+ [=Rh]
1995
+ [17NH1]
1996
+ [K]
1997
+ [57Fe+2]
1998
+ [218Po]
1999
+ [/SnH1]
2000
+ [=13CH2]
2001
+ [TlH2+1]
2002
+ [Sr]
2003
+ [88Rb+1]
2004
+ [68GaH3]
2005
+ [36SH1]
2006
+ [\SH1+1]
2007
+ [165Er]
2008
+ [/S+1]
2009
+ [RuH1]
2010
+ [=Tb]
2011
+ [Mn+3]
2012
+ [12CH2]
2013
+ [98Tc+5]
2014
+ [#99Tc]
2015
+ [/19F]
2016
+ [Be+2]
2017
+ [15C-1]
2018
+ [Os-1]
2019
+ [=MoH2]
2020
+ [191Pt]
2021
+ [134Cs+1]
2022
+ [120Sn]
2023
+ [6Li+1]
2024
+ [141Pr+3]
2025
+ [SeH1]
2026
+ [\GeH3]
2027
+ [AgH1]
2028
+ [168Tm]
2029
+ [26Al]
2030
+ [/S@]
2031
+ [ZrH2+2]
2032
+ [130Sb]
2033
+ [GeH2]
2034
+ [170Yb]
2035
+ [129Xe]
2036
+ [15N-1]
2037
+ [228Pa]
2038
+ [/Ru]
2039
+ [#B-1]
2040
+ [As-1]
2041
+ [41Ar]
2042
+ [103Ag]
2043
+ [Tc]
2044
+ [120Sb]
2045
+ [P-2]
2046
+ [/W]
2047
+ [22NH1]
2048
+ [=15NH1+1]
2049
+ [\At]
2050
+ [Pb+2]
2051
+ [242Am]
2052
+ [148Sm]
2053
+ [56Fe]
2054
+ [222Ra]
2055
+ [251Cf]
2056
+ [1HH1]
2057
+ [RuH1+2]
2058
+ [61Cu]
2059
+ [#As+1]
2060
+ [114In]
2061
+ [38PH3]
2062
+ [=12C]
2063
+ [88Kr]
2064
+ [/CH0]
2065
+ [HH1]
2066
+ [123Te]
2067
+ [F-1]
2068
+ [117Sb]
2069
+ [IH2]
2070
+ [152Sm]
2071
+ [42K]
2072
+ [189Re]
2073
+ [115Sn]
2074
+ [212Bi]
2075
+ [Mn]
2076
+ [31Si]
2077
+ [/18OH1]
2078
+ [Ba+1]
2079
+ [Ni-1]
2080
+ [245Am]
2081
+ [#Te]
2082
+ [104Tc]
2083
+ [Ir+3]
2084
+ [PdH2]
2085
+ [V+4]
2086
+ [Cr+2]
2087
+ [=Pd-3]
2088
+ [12C@H1]
2089
+ [94Mo]
2090
+ [RhH2]
2091
+ [89Zr]
2092
+ [\NH2+1]
2093
+ [13C@H1]
2094
+ [\35Cl]
2095
+ [12C@@H1]
2096
+ [TiH1+3]
2097
+ [\3H]
2098
+ [=BH0]
2099
+ [13O]
2100
+ [\14CH2]
2101
+ [205Tl]
2102
+ [167Yb]
2103
+ [27Al]
2104
+ [51Cr+3]
2105
+ [178Re]
2106
+ [Fe-3]
2107
+ [Eu+3]
2108
+ [84Kr]
2109
+ [166Ho]
2110
+ [244Cf]
2111
+ [PH0]
2112
+ [111Ag]
2113
+ [=IH1]
2114
+ [51V]
2115
+ [FeH4-3]
2116
+ [NH2+1]
2117
+ [\BH3-1]
2118
+ [245Bk]
2119
+ [\SiH1]
2120
+ [151Gd]
2121
+ [100Tc]
2122
+ [/14NH1]
2123
+ [98Tc+7]
2124
+ [=Eu]
2125
+ [197Pt]
2126
+ [\BH1-1]
2127
+ [80Rb]
2128
+ [216Po]
2129
+ [Mo+2]
2130
+ [88Zr]
2131
+ [/-Ring2]
2132
+ [230Pa]
2133
+ [123Xe]
2134
+ [/Si@]
2135
+ [34S-1]
2136
+ [At]
2137
+ [Hg-1]
2138
+ [126Te]
2139
+ [44Ca]
2140
+ [Yb]
2141
+ [Fe+1]
2142
+ [/Br]
2143
+ [14N+1]
2144
+ [99Y+3]
2145
+ [75As+5]
2146
+ [100Mo]
2147
+ [205Bi]
2148
+ [Si+3]
2149
+ [=Bi+1]
2150
+ [148Tb]
2151
+ [212Ra]
2152
+ [#AsH1]
2153
+ [142Nd]
2154
+ [127Sb]
2155
+ [Sb-1]
2156
+ [=77Se]
2157
+ [17OH1-1]
2158
+ [18N]
2159
+ [128I]
2160
+ [Sb+3]
2161
+ [=Re+1]
2162
+ [20Ne]
2163
+ [TlH3]
2164
+ [151Eu]
2165
+ [/Si]
2166
+ [99Ru]
2167
+ [124I-1]
2168
+ [CrH2]
2169
+ [MoH4]
2170
+ [240U]
2171
+ [162Yb]
2172
+ [22Na]
2173
+ [AsH1-1]
2174
+ [ThH4]
2175
+ [#Os-1]
2176
+ [90Sr]
2177
+ [74Ge]
2178
+ [19OH2]
2179
+ [149Tb]
2180
+ [\13CH1]
2181
+ [43Sc]
2182
+ [188Ir]
2183
+ [255Fm]
2184
+ [197Au]
2185
+ [SeH1+1]
2186
+ [Rh+2]
2187
+ [Tl+3]
2188
+ [\Br-1]
2189
+ [36Cl-1]
2190
+ [/I]
2191
+ [121Te]
2192
+ [ClH1]
2193
+ [Sn]
2194
+ [\SH0]
2195
+ [186Re]
2196
+ [188Pt]
2197
+ [\13CH3]
2198
+ [Si]
2199
+ [15NH2+1]
2200
+ [/2H]
2201
+ [=Fe+1]
2202
+ [209BiH3]
2203
+ [152Eu]
2204
+ [/CH2]
2205
+ [20CH1]
2206
+ [38Cl]
2207
+ [Bi-2]
2208
+ [94Tc+7]
2209
+ [\GeH2]
2210
+ [11B]
2211
+ [/Si@H1]
2212
+ [68Cu]
2213
+ [#Mn]
2214
+ [181Re]
2215
+ [Os]
2216
+ [Br+3]
2217
+ [230Ra]
2218
+ [156Tb]
2219
+ [152Gd]
2220
+ [/NH3+1]
2221
+ [Bk]
2222
+ [190Os]
2223
+ [ClH4+3]
2224
+ [Cl-1]
2225
+ [\C]
2226
+ [\SiH1-1]
2227
+ [#I]
2228
+ [Lu]
2229
+ [SnH1+1]
2230
+ [162Tm]
2231
+ [236U]
2232
+ [Cr+3]
2233
+ [122Sb]
2234
+ [131Sb]
2235
+ [209Po]
2236
+ [Ar]
2237
+ [166Ho+3]
2238
+ [114Sn]
2239
+ [48Ti]
2240
+ [Ti+4]
2241
+ [121Sb]
2242
+ [190Ir]
2243
+ [W]
2244
+ [Cs]
2245
+ [SnH1+3]
2246
+ [105Rh+3]
2247
+ [Mo-1]
2248
+ [C@H1]
2249
+ [MgH2]
2250
+ [AlH2]
2251
+ [20CH3]
2252
+ [Tb]
2253
+ [92Y]
2254
+ [/15NH2]
2255
+ [#C+1]
2256
+ [17O]
2257
+ [144Ce]
2258
+ [162Er]
2259
+ [175Yb]
2260
+ [80Br]
2261
+ [127Sb+3]
2262
+ [77Se]
2263
+ [177Hf]
2264
+ [64Ga]
2265
+ [144Cs]
2266
+ [Al+1]
2267
+ [139Ba]
2268
+ [=CH1]
2269
+ [\Sb]
2270
+ [89Rb]
2271
+ [142Sm]
2272
+ [89Kr]
2273
+ [=15NH1]
2274
+ [=Branch2]
2275
+ [Y+3]
2276
+ [13NH2]
2277
+ [14NH4+1]
2278
+ [=10B]
2279
+ [67Ga]
2280
+ [=P@@]
2281
+ [57Ni]
2282
+ [CH3-1]
2283
+ [223Ra]
2284
+ [62Zn]
2285
+ [SH1-1]
2286
+ [=Ir]
2287
+ [CH2+1]
2288
+ [212PbH2]
2289
+ [GeH6-2]
2290
+ [=Ho]
2291
+ [\CH2]
2292
+ [125IH1]
2293
+ [65Ni]
2294
+ [124Sb]
2295
+ [246Cm]
2296
+ [167Dy]
2297
+ [CH0]
2298
+ [224Rn]
2299
+ [Th]
2300
+ [B-1]
2301
+ [=11CH1]
2302
+ [=106Ru]
2303
+ [LiH1]
2304
+ [241Cm]
2305
+ [=99Tc]
2306
+ [\Tl]
2307
+ [RuH1+1]
2308
+ [OsH2]
2309
+ [ZrH2]
2310
+ [93Tc]
2311
+ [71Ge]
2312
+ [Te+4]
2313
+ [143Cs]
2314
+ [140La]
2315
+ [131I-1]
2316
+ [172Lu]
2317
+ [78Se]
2318
+ [6He]
2319
+ [238U]
2320
+ [#As]
2321
+ [#Ru-1]
2322
+ [=ZrH2]
2323
+ [204Pb]
2324
+ [82Se+4]
2325
+ [205Po]
2326
+ [=B+1]
2327
+ [=CH1-1]
2328
+ [=ReH1]
2329
+ [191Os+4]
2330
+ [60Co+2]
2331
+ [Pd-2]
2332
+ [/B-1]
2333
+ [/14C]
2334
+ [Ne]
2335
+ [51Cr+6]
2336
+ [SeH3]
2337
+ [183Hf]
2338
+ [\AlH2]
2339
+ [Ru]
2340
+ [B@-1]
2341
+ [186W]
2342
+ [S@]
2343
+ [SiH4-1]
2344
+ [194Pb]
2345
+ [239Th]
2346
+ [105Ru]
2347
+ [SbH1-1]
2348
+ [=BH1-1]
2349
+ [107Ag]
2350
+ [115Ag]
2351
+ [PtH4]
2352
+ [154Eu]
2353
+ [14NH1]
2354
+ [BiH4]
2355
+ [70Zn]
2356
+ [#Al]
2357
+ [\AsH1]
2358
+ [174Hf+4]
2359
+ [#15N+1]
2360
+ [CH1]
2361
+ [157Lu]
2362
+ [89Nb]
2363
+ [GeH5-1]
2364
+ [50Ti]
2365
+ [207Po]
2366
+ [31P-3]
2367
+ [\S@@]
2368
+ [47Ca]
2369
+ [Dy]
2370
+ [Ag+1]
2371
+ [147Pr]
2372
+ [=238U]
2373
+ [139Nd]
2374
+ [CrH1+2]
2375
+ [230Th]
2376
+ [216Bi]
2377
+ [OH1+1]
2378
+ [55Co]
2379
+ [#Se]
2380
+ [83Sr]
2381
+ [158Dy]
2382
+ [#Co]
2383
+ [35SH2]
2384
+ [C@]
2385
+ [185Os]
2386
+ [161Dy]
2387
+ [/F]
2388
+ [\SbH1]
2389
+ [210Po]
2390
+ [34ClH1]
2391
+ [\-Ring1]
2392
+ [125Te+4]
2393
+ [141La]
2394
+ [NH2-1]
2395
+ [30S]
2396
+ [166Dy]
2397
+ [11CH3]
2398
+ [TlH1]
2399
+ [OsH4]
2400
+ [Re-1]
2401
+ [AlH6-3]
2402
+ [202Po]
2403
+ [=C+1]
2404
+ [=Se+1]
2405
+ [N]
2406
+ [32SH2]
2407
+ [=Branch1]
2408
+ [P@@H1]
2409
+ [Pd-3]
2410
+ [17OH2]
2411
+ [Si+2]
2412
+ [#Tc]
2413
+ [188Os]
2414
+ [195Hg]
2415
+ [244Cm]
2416
+ [133Ba]
2417
+ [PH2-1]
2418
+ [15NH1+1]
2419
+ [6Li]
2420
+ [138Nd]
2421
+ [PbH3]
2422
+ [10CH4]
2423
+ [#Os+2]
2424
+ [22CH2]
2425
+ [/At]
2426
+ [214Bi]
2427
+ [228Ra]
2428
+ [Ba]
2429
+ [14C-1]
2430
+ [Cs+1]
2431
+ [239Am]
2432
+ [85Sr]
2433
+ [/OH1+1]
2434
+ [29Al]
2435
+ [NbH2]
2436
+ [70Ga]
2437
+ [59Fe]
2438
+ [RuH1+3]
2439
+ [111Sn]
2440
+ [Ta]
2441
+ [112Pd]
2442
+ [Rh+3]
2443
+ [Ru-3]
2444
+ [245Cm]
2445
+ [=N]
2446
+ [Ge+4]
2447
+ [\13CH2]
2448
+ [=SiH1+1]
2449
+ [59Fe+2]
2450
+ [202Tl]
2451
+ [117Sn+2]
2452
+ [40Ar]
2453
+ [156Dy]
2454
+ [79Rb+1]
2455
+ [/HgH1]
2456
+ [15N+1]
2457
+ [O]
2458
+ [125I-1]
2459
+ [99Tc+6]
2460
+ [186Ir]
2461
+ [SiH1]
2462
+ [/13C]
2463
+ [/SnH3]
2464
+ [131Cs]
2465
+ [111In+3]
2466
+ [Pm]
2467
+ [150Sm]
2468
+ [117In]
2469
+ [20C]
2470
+ [194Bi]
2471
+ [16O]
2472
+ [Si+4]
2473
+ [=I]
2474
+ [Mo+1]
2475
+ [Pr+3]
2476
+ [80Kr]
2477
+ [=10CH1]
2478
+ [49Cr]
2479
+ [248Cf]
2480
+ [160Gd]
2481
+ [Ca]
2482
+ [132Te]
2483
+ [/P+1]
2484
+ [48Sc]
2485
+ [=RuH1]
2486
+ [150Eu]
2487
+ [79Kr]
2488
+ [Al+3]
2489
+ [#Si]
2490
+ [Ca+2]
2491
+ [SeH2+1]
2492
+ [#Si-1]
2493
+ [Ga-1]
2494
+ [/OH2+1]
2495
+ [Se-2]
2496
+ [195Au]
2497
+ [102Ag]
2498
+ [#P+1]
2499
+ [115Cd]
2500
+ [14NH2]
2501
+ [=RuH2]
2502
+ [243Cm]
2503
+ [Se+6]
2504
+ [209Pb]
2505
+ [Ge@@]
2506
+ [ClH3+3]
2507
+ [16NH3]
2508
+ [248Am]
2509
+ [#34S+1]
2510
+ [12N+1]
2511
+ [#WH1]
2512
+ [135Ce]
2513
+ [240Am]
2514
+ [=SbH1]
2515
+ [SbH4+1]
2516
+ [32PH1]
2517
+ [80Sr]
2518
+ [=CH1+1]
2519
+ [=33S]
2520
+ [ZnH2]
2521
+ [\Se+1]
2522
+ [11BH3]
2523
+ [203Hg+2]
2524
+ [15OH1]
2525
+ [Tl]
2526
+ [Hs]
2527
+ [/PH0]
2528
+ [87Sr]
2529
+ [=N+1]
2530
+ [\Hg]
2531
+ [=15O]
2532
+ [100Pd]
2533
+ [10CH1]
2534
+ [Pd-4]
2535
+ [98Tc]
2536
+ [226Ac]
2537
+ [13CH2]
2538
+ [#Lu]
2539
+ [B@H1-1]
2540
+ [240Np]
2541
+ [110Ag+1]
2542
+ [137Cs+1]
2543
+ [=15CH1]
2544
+ [147Eu]
2545
+ [257Md]
2546
+ [#Hf+1]
2547
+ [=Mn-1]
2548
+ [\OH0]
2549
+ [=SnH2]
2550
+ [Se@@H1]
2551
+ [Zr]
2552
+ [32SH1]
2553
+ [#TaH1]
2554
+ [198Au+3]
2555
+ [38ClH1]
2556
+ [33SH1]
2557
+ [\Cl-1]
2558
+ [38PH1]
2559
+ [11C@H1]
2560
+ [9CH1]
2561
+ [134Ce]
2562
+ [Si@H1]
2563
+ [=Au]
2564
+ [AsH1+1]
2565
+ [15CH1]
2566
+ [/PH1+1]
2567
+ [Ce+3]
2568
+ [CoH1+1]
2569
+ [Os+8]
2570
+ [/125Te]
2571
+ [145Gd]
2572
+ [204Hg]
2573
+ [=Pt]
2574
+ [#13CH1]
2575
+ [W+2]
2576
+ [RuH2]
2577
+ [#Sn]
2578
+ [=Ge]
2579
+ [Tc+7]
2580
+ [37Cl-1]
2581
+ [237U]
2582
+ [16N]
2583
+ [/Si-2]
2584
+ [63Cu]
2585
+ [WH4]
2586
+ [Yb+2]
2587
+ [=SH1-1]
2588
+ [121Sn+2]
2589
+ [176Hf]
2590
+ [217Po]
2591
+ [177Lu]
2592
+ [176Lu]
2593
+ [78Ge]
2594
+ [130Cs+1]
2595
+ [211Pb]
2596
+ [Hg]
2597
+ [81Br]
2598
+ [=NiH1]
2599
+ [116In]
2600
+ [GeH3-1]
2601
+ [45Ti]
2602
+ [15C]
2603
+ [=OsH1]
2604
+ [BH3-1]
2605
+ [128Ba]
2606
+ [165Tm]
2607
+ [40K]
2608
+ [SnH2+1]
2609
+ [=Sm]
2610
+ [41K]
2611
+ [154Sm]
2612
+ [158Eu]
2613
+ [97Mo]
2614
+ [116Sb]
2615
+ [207Pb]
2616
+ [11C@@H1]
2617
+ [Ti+3]
2618
+ [Eu+2]
2619
+ [=14NH1]
2620
+ [=IH2]
2621
+ [142Ce]
2622
+ [=14O]
2623
+ [Cd-1]
2624
+ [Os+2]
2625
+ [#Os-2]
2626
+ [Sn+4]
2627
+ [Fe-2]
2628
+ [P]
2629
+ [226Th]
2630
+ [SrH2]
2631
+ [34SH2]
2632
+ [193Ir]
2633
+ [TaH3]
2634
+ [N@@+1]
2635
+ [41Ca]
2636
+ [125Cs]
2637
+ [200Au]
2638
+ [155Tb]
2639
+ [13CH4]
2640
+ [34SH1]
2641
+ [#Ring2]
2642
+ [111In]
2643
+ [=235U+2]
2644
+ [149Sm]
2645
+ [19CH2]
2646
+ [132Cs+1]
2647
+ [44K]
2648
+ [18OH1-1]
2649
+ [=Ring3]
2650
+ [/CH1]
2651
+ [64Cu+2]
2652
+ [159Gd]
2653
+ [\OH2+1]
2654
+ [#11CH1]
2655
+ [=U+2]
2656
+ [82Se]
2657
+ [RuH6]
2658
+ [249Cf]
2659
+ [Na+1]
2660
+ [O-2]
2661
+ [#Zr+1]
2662
+ [201Tl+1]
2663
+ [86Kr]
2664
+ [/11C]
2665
+ [/3H]
2666
+ [As@@]
2667
+ [124I]
2668
+ [Fe-4]
2669
+ [Fe+4]
2670
+ [75Br]
2671
+ [147Nd]
2672
+ [128Te]
2673
+ [141Ce]
2674
+ [Bi+3]
2675
+ [103Pd+2]
2676
+ [198Hg]
2677
+ [199Pb]
2678
+ [101Rh]
2679
+ [=Cr]
2680
+ [136Ba]
2681
+ [127Cs]
2682
+ [135Cs]
2683
+ [56Mn]
2684
+ [NiH1]
2685
+ [55Mn]
2686
+ [=V+2]
2687
+ [178W]
2688
+ [139Ce]
2689
+ [167Tm]
2690
+ [147Pm]
2691
+ [#11C-1]
2692
+ [188Re]
2693
+ [Fm]
2694
+ [Yb+3]
2695
+ [Gd]
2696
+ [Fe+5]
2697
+ [NH2]
2698
+ [57Co]
2699
+ [88Sr+2]
2700
+ [147Gd]
2701
+ [Cf]
2702
+ [79Br-1]
2703
+ [=Sc]
2704
+ [#CH0]
2705
+ [22CH4]
2706
+ [135Ba]
2707
+ [237Am]
2708
+ [146Gd]
2709
+ [Te@@]
2710
+ [N@@]
2711
+ [/13CH3]
2712
+ [Sm]
2713
+ [73Ge]
2714
+ [71As]
2715
+ [PbH2]
2716
+ [TaH1]
2717
+ [122Xe]
2718
+ [165Dy]
2719
+ [123Sb]
2720
+ [67GaH3]
2721
+ [/Se+1]
2722
+ [B+1]
2723
+ [83Rb]
2724
+ [Cu+2]
2725
+ [13C@]
2726
+ [AuH1]
2727
+ [\P]
2728
+ [157Eu]
2729
+ [85Rb]
2730
+ [Pt]
2731
+ [235Np]
2732
+ [80BrH1]
2733
+ [\18F]
2734
+ [P@]
2735
+ [203Po]
2736
+ [125Cs+1]
2737
+ [P+1]
2738
+ [=18CH2]
2739
+ [45K+1]
2740
+ [Co-3]
2741
+ [ZnH1+1]
2742
+ [57Co+2]
2743
+ [=PbH2]
2744
+ [=Ti+1]
2745
+ [174Ta]
2746
+ [#Ho]
2747
+ [/B+1]
2748
+ [\37Cl]
2749
+ [100Tc+5]
2750
+ [2H]
2751
+ [13B]
2752
+ [155Sm]
2753
+ [#N+1]
2754
+ [NH1-1]
2755
+ [32P]
2756
+ [58Co+3]
2757
+ [/35Cl]
2758
+ [=NH1+1]
2759
+ [=Pr+1]
2760
+ [Ir+2]
2761
+ [/Pb]
2762
+ [15NH3]
2763
+ [CuH2]
2764
+ [114In+3]
2765
+ [Ru+1]
2766
+ [Fe-1]
2767
+ [198Bi]
2768
+ [SH2]
2769
+ [RhH1+2]
2770
+ [176W]
2771
+ [200Hg]
2772
+ [Hf+4]
2773
+ [10BH1]
2774
+ [Hg-2]
2775
+ [179W]
2776
+ [252Fm]
2777
+ [PbH4]
2778
+ [/O]
2779
+ [He]
2780
+ [=Hg]
2781
+ [183W]
2782
+ [157Ho]
2783
+ [Be]
2784
+ [#Ti+1]
2785
+ [Rh-4]
2786
+ [=S-1]
2787
+ [72Se]
2788
+ [#Sm]
2789
+ [=9C]
2790
+ [Be+1]
2791
+ [180Ta]
2792
+ [/-Ring3]
2793
+ [/IH1+1]
2794
+ [Ring2]
2795
+ [/H]
2796
+ [129Sb]
2797
+ [174Yb]
2798
+ [149Gd]
2799
+ [=Br]
2800
+ [Mn+2]
2801
+ [36S]
2802
+ [14C@H1]
2803
+ [34S]
2804
+ [CoH1]
2805
+ [\TlH1]
2806
+ [170Tm]
2807
+ [68Ge+4]
2808
+ [210PoH2]
2809
+ [=Os]
2810
+ [179Lu]
2811
+ [/AlH1]
2812
+ [195Po]
2813
+ [Ru+5]
2814
+ [81BrH1]
2815
+ [17FH1]
2816
+ [#S-1]
2817
+ [136Eu+3]
2818
+ [NH3+1]
2819
+ [68GaH1]
2820
+ [28Mg]
2821
+ [=O+1]
2822
+ [#Fe]
2823
+ [60Ni+2]
2824
+ [Rh+1]
2825
+ [43Ca+2]
2826
+ [/As]
2827
+ [PdH1+1]
2828
+ [141Cs]
2829
+ [=AsH1]
2830
+ [#V]
2831
+ [229Rn]
2832
+ [17CH1]
2833
+ [95Ru]
2834
+ [67Zn]
2835
+ [153Pm]
2836
+ [#P-1]
2837
+ [Bh]
2838
+ [=Cl]
2839
+ [80Se]
2840
+ [RuH4]
2841
+ [143Pm]
2842
+ [=N-1]
2843
+ [#Os]
2844
+ [N@+1]
2845
+ [/Si@@H1]
2846
+ [Sg]
2847
+ [76Se]
2848
+ [=AsH3]
2849
+ [96Tc]
2850
+ [=P+1]
2851
+ [167Ho]
2852
+ [InH3]
2853
+ [193Po]
2854
+ [165Dy+3]
2855
+ [95Y]
2856
+ [C+1]
2857
+ [=Zr+2]
2858
+ [24Na+1]
2859
+ [89Zr+4]
2860
+ [189Ir]
2861
+ [=Bi]
2862
+ [198Pb]
2863
+ [#Gd]
2864
+ [La]
2865
+ [=Hf+2]
2866
+ [B@@-1]
2867
+ [/Cl]
2868
+ [GaH3]
2869
+ [93Zr]
2870
+ [251Es]
2871
+ [111InH2]
2872
+ [175Tm]
2873
+ [/SiH2+1]
2874
+ [H+1]
2875
+ [163Dy]
2876
+ [#Eu]
2877
+ [31S]
2878
+ [16O-1]
2879
+ [Mt]
2880
+ [110Sn]
2881
+ [Ti-2]
2882
+ [54Fe]
2883
+ [Mo+3]
2884
+ [/SH1+1]
2885
+ [72BrH1]
2886
+ [=TlH1]
2887
+ [Sn+1]
2888
+ [\H]
2889
+ [14CH3-1]
2890
+ [57Co+3]
2891
+ [14CH1-1]
2892
+ [145Sm]
2893
+ [Zr+2]
2894
+ [197Hg+1]
2895
+ [Ru+6]
2896
+ [17NH4+1]
2897
+ [60Co]
2898
+ [77Br]
2899
+ [193Pt]
2900
+ [35S]
2901
+ [133IH1]
2902
+ [147Tb]
2903
+ [95Mo]
2904
+ [52Ti]
2905
+ [129Cs]
2906
+ [133Te]
2907
+ [FH0]
2908
+ [=Ring2]
2909
+ [\B-1]
2910
+ [52Mn]
2911
+ [/PH3+1]
2912
+ [58Fe]
2913
+ [177Re]
2914
+ [49Sc]
2915
+ [52Mn+2]
2916
+ [250Es]
2917
+ [=99Tc+3]
2918
+ [53Cr+6]
2919
+ [206Po]
2920
+ [Pu]
2921
+ [/Si@@]
2922
+ [130Cs]
2923
+ [=SeH2]
2924
+ [IrH1+2]
2925
+ [180Hf]
2926
+ [83Rb+1]
2927
+ [15NH3+1]
2928
+ [Ga+3]
2929
+ [56Ni]
2930
+ [\Si-1]
2931
+ [13CH3]
2932
+ [62Ni]
2933
+ [110Te]
2934
+ [93Nb]
2935
+ [Sc+3]
2936
+ [88Sr]
2937
+ [12CH1-1]
2938
+ [CH3+1]
2939
+ [\13C]
2940
+ [151Tb]
2941
+ [77BrH1]
2942
+ [\S+1]
2943
+ [PH2]
2944
+ [\NH0]
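The tokens above complete the SELFIES vocabulary the model emits: per-atom tokens carrying isotope, charge, explicit-hydrogen, and stereo annotations, plus structural tokens such as [Ring2]. Generated token sequences are mapped back to SMILES with the `selfies` package pinned in the environment below (selfies==2.1.1); a minimal sketch, assuming only that the package is installed and using an illustrative token sequence rather than one drawn from the dataset:

import selfies as sf

# Decode a SELFIES string assembled from vocabulary tokens into SMILES.
selfies_string = "[13CH3][C][O]"
print(sf.decoder(selfies_string))  # e.g. "[13CH3]CO"

The inference scripts below apply exactly this `sf.decoder` step to the model output before writing predictions.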
environment.yaml ADDED
@@ -0,0 +1,129 @@
 
1
+ name: molecule
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h5eee18b_5
8
+ - ca-certificates=2024.3.11=h06a4308_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.4.4=h6a678d5_0
11
+ - libgcc-ng=11.2.0=h1234567_1
12
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
13
+ - libgfortran4=7.5.0=ha8ba4b0_17
14
+ - libgomp=11.2.0=h1234567_1
15
+ - libstdcxx-ng=11.2.0=h1234567_1
16
+ - libuuid=1.41.5=h5eee18b_0
17
+ - mpi=1.0=mpich
18
+ - mpi4py=3.1.4=py311hfc96bbd_0
19
+ - mpich=3.3.2=hc856adb_0
20
+ - ncurses=6.4=h6a678d5_0
21
+ - openssl=3.0.13=h7f8727e_0
22
+ - pip=23.3.1=py311h06a4308_0
23
+ - python=3.11.9=h955ad1f_0
24
+ - readline=8.2=h5eee18b_0
25
+ - setuptools=68.2.2=py311h06a4308_0
26
+ - sqlite=3.41.2=h5eee18b_0
27
+ - tk=8.6.12=h1ccaba5_0
28
+ - wheel=0.41.2=py311h06a4308_0
29
+ - xz=5.4.6=h5eee18b_0
30
+ - zlib=1.2.13=h5eee18b_0
31
+ - pip:
32
+ - aiohttp==3.9.5
33
+ - aiosignal==1.3.1
34
+ - annotated-types==0.6.0
35
+ - appdirs==1.4.4
36
+ - attrs==23.2.0
37
+ - blis==0.7.11
38
+ - blobfile==2.1.1
39
+ - catalogue==2.0.10
40
+ - certifi==2024.2.2
41
+ - charset-normalizer==3.3.2
42
+ - click==8.1.7
43
+ - cloudpathlib==0.16.0
44
+ - confection==0.1.4
45
+ - cymem==2.0.8
46
+ - datasets==2.19.0
47
+ - dill==0.3.8
48
+ - docker-pycreds==0.4.0
49
+ - fcd==1.2.2
50
+ - filelock==3.13.4
51
+ - frozenlist==1.4.1
52
+ - fsspec==2024.3.1
53
+ - gitdb==4.0.11
54
+ - gitpython==3.1.43
55
+ - huggingface-hub==0.22.2
56
+ - idna==3.7
57
+ - jinja2==3.1.3
58
+ - joblib==1.4.0
59
+ - levenshtein==0.25.1
60
+ - lxml==4.9.4
61
+ - markupsafe==2.1.5
62
+ - mpmath==1.3.0
63
+ - multidict==6.0.5
64
+ - multiprocess==0.70.16
65
+ - murmurhash==1.0.10
66
+ - networkx==3.2.1
67
+ - nltk==3.8.1
68
+ - numpy==1.26.4
69
+ - nvidia-cublas-cu12==12.1.3.1
70
+ - nvidia-cuda-cupti-cu12==12.1.105
71
+ - nvidia-cuda-nvrtc-cu12==12.1.105
72
+ - nvidia-cuda-runtime-cu12==12.1.105
73
+ - nvidia-cudnn-cu12==8.9.2.26
74
+ - nvidia-cufft-cu12==11.0.2.54
75
+ - nvidia-curand-cu12==10.3.2.106
76
+ - nvidia-cusolver-cu12==11.4.5.107
77
+ - nvidia-cusparse-cu12==12.1.0.106
78
+ - nvidia-nccl-cu12==2.20.5
79
+ - nvidia-nvjitlink-cu12==12.4.127
80
+ - nvidia-nvtx-cu12==12.1.105
81
+ - packaging==24.0
82
+ - pandas==2.2.1
83
+ - pfzy==0.3.4
84
+ - pillow==10.3.0
85
+ - preshed==3.0.9
86
+ - prompt-toolkit==3.0.43
87
+ - protobuf==4.25.3
88
+ - psutil==5.9.8
89
+ - pyarrow==15.0.2
90
+ - pyarrow-hotfix==0.6
91
+ - pycryptodomex==3.20.0
92
+ - pydantic==2.6.4
93
+ - pydantic-core==2.16.3
94
+ - python-dateutil==2.9.0.post0
95
+ - pytz==2024.1
96
+ - pyyaml==6.0.1
97
+ - rapidfuzz==3.8.1
98
+ - rdkit==2023.9.5
99
+ - regex==2023.12.25
100
+ - requests==2.31.0
101
+ - safetensors==0.4.2
102
+ - scipy==1.13.0
103
+ - selfies==2.1.1
104
+ - sentry-sdk==1.44.1
105
+ - setproctitle==1.3.3
106
+ - six==1.16.0
107
+ - smart-open==6.4.0
108
+ - smmap==5.0.1
109
+ - spacy-legacy==3.0.12
110
+ - spacy-loggers==1.0.5
111
+ - srsly==2.4.8
112
+ - sympy==1.12
113
+ - thinc==8.2.3
114
+ - tokenizers==0.15.2
115
+ - torch==2.3.0
116
+ - tqdm==4.66.2
117
+ - transformers==4.39.3
118
+ - triton==2.3.0
119
+ - typer==0.9.4
120
+ - typing-extensions==4.10.0
121
+ - tzdata==2024.1
122
+ - urllib3==2.2.1
123
+ - wandb==0.16.6
124
+ - wasabi==1.1.2
125
+ - wcwidth==0.2.13
126
+ - weasel==0.3.4
127
+ - xxhash==3.4.1
128
+ - yarl==1.9.4
129
+ prefix: /opt/conda/envs/molecule
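The environment should be reproducible with the standard conda workflow (a sketch, assuming a local conda installation):

conda env create -f environment.yaml
conda activate molecule

The `prefix:` line above records the install path on the machine that exported the file; `conda env create` resolves its own location from the `name:` field, so the prefix can be left as-is or deleted.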
inference.py ADDED
@@ -0,0 +1,202 @@
 
1
+ import torch
2
+ import argparse
3
+ import selfies as sf
4
+ from tqdm import tqdm
5
+ from transformers import set_seed
6
+ from src.scripts.mytokenizers import Tokenizer
7
+ from src.improved_diffusion import gaussian_diffusion as gd
8
+ from src.improved_diffusion import dist_util, logger
9
+ from src.improved_diffusion.respace import SpacedDiffusion
10
+ from src.improved_diffusion.transformer_model import TransformerNetModel
11
+ from src.improved_diffusion.script_util import (
12
+ model_and_diffusion_defaults,
13
+ add_dict_to_argparser,
14
+ )
15
+ from src.scripts.mydatasets import Lang2molDataset_eval
16
+
17
+
18
+ def main():
19
+ set_seed(42)
20
+ args = create_argparser().parse_args()
21
+
22
+ # dist_util.setup_dist()
23
+ logger.configure()
24
+ args.sigma_small = True
25
+
26
+ # args.diffusion_steps = 200 #500 # DEBUG
27
+
28
+ if args.experiment == "random1":
29
+ args.experiment = "random"
30
+ logger.log("creating model and diffusion...")
31
+ tokenizer = Tokenizer()
32
+ model = TransformerNetModel(
33
+ in_channels=args.model_in_channels,
34
+ model_channels=args.model_model_channels,
35
+ dropout=args.model_dropout,
36
+ vocab_size=len(tokenizer),
37
+ hidden_size=args.model_hidden_size,
38
+ num_attention_heads=args.model_num_attention_heads,
39
+ num_hidden_layers=args.model_num_hidden_layers,
40
+ )
41
+ model.eval()
42
+ diffusion = SpacedDiffusion(
43
+ use_timesteps=[i for i in range(0, args.diffusion_steps, 10)],
44
+ betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
45
+ model_mean_type=(gd.ModelMeanType.START_X),
46
+ model_var_type=((gd.ModelVarType.FIXED_LARGE)),
47
+ loss_type=gd.LossType.E2E_MSE,
48
+ rescale_timesteps=True,
49
+ model_arch="transformer",
50
+ training_mode="e2e",
51
+ )
52
+
53
+ model.load_state_dict(
54
+ dist_util.load_state_dict(args.model_path, map_location="cpu")
55
+ )
56
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
57
+ logger.log(f"the parameter count is {pytorch_total_params}")
58
+ model.to(dist_util.dev())
59
+ model.eval()
60
+
61
+ logger.log("sampling...")
62
+ print("--" * 30)
63
+ print(f"Loading {args.split} set")
64
+ print("--" * 30)
65
+
66
+ validation_dataset = Lang2molDataset_eval(
67
+ dir=args.dataset_path,
68
+ tokenizer=tokenizer,
69
+ split=args.split,
70
+ corrupt_prob=0.0,
71
+ token_max_length=args.token_max_length,
72
+ dataset_name=args.dataset_name,
73
+ )
74
+ print("-------------------- DATASET INFO --------------------")
75
+ print(f"Size: {len(validation_dataset)} samples")
76
+ print(f'Sample shape: {validation_dataset[0]["caption_state"].shape}')
77
+
78
+ print(f"Use DDIM: {args.use_ddim}")
79
+ sample_fn = (
80
+ diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
81
+ )
82
+
83
+ print(f"Batch size: {args.batch_size}")
84
+ next_batch_start = args.start
85
+ next_batch_end = min(next_batch_start + args.batch_size, len(validation_dataset))  # clamp so the first batch cannot index past the dataset
86
+ all_outputs = []
87
+ all_caption = []
88
+ all_smiles = []
89
+ pbar = tqdm(
90
+ total=len(validation_dataset) // args.batch_size + 1
91
+ if len(validation_dataset) % args.batch_size != 0
92
+ else len(validation_dataset) // args.batch_size
93
+ )
94
+ while True:
95
+ sample = [
96
+ (
97
+ validation_dataset[i]["caption_state"],
98
+ validation_dataset[i]["caption_mask"],
99
+ validation_dataset[i]["caption"],
100
+ validation_dataset[i]["smiles"],
101
+ )
102
+ for i in range(next_batch_start, next_batch_end)
103
+ ]
104
+ caption_state = torch.concat([i[0] for i in sample], dim=0)
105
+ caption_mask = torch.concat([i[1] for i in sample], dim=0)
106
+ caption = [i[2] for i in sample]
107
+ smiles = [i[3] for i in sample]
108
+
109
+ outputs = sample_fn(
110
+ model,
111
+ (caption_state.size(0), 256, model.in_channels),  # match the actual (possibly partial) batch
112
+ clip_denoised=args.clip_denoised,
113
+ denoised_fn=None,
114
+ model_kwargs={},
115
+ top_p=args.top_p,
116
+ progress=True,
117
+ caption=(caption_state, caption_mask),
118
+ )
119
+
120
+ logits = model.get_logits(torch.tensor(outputs).cuda())
121
+ cands = torch.topk(logits, k=1, dim=-1)
122
+ outputs = cands.indices
123
+ outputs = outputs.squeeze(-1)
124
+ outputs = tokenizer.decode(outputs)
125
+
126
+ with open(args.outputdir, "a") as f:
127
+ for i, x in enumerate(outputs):
128
+ f.write(
129
+ caption[i]
130
+ + "\t"
131
+ + smiles[i]
132
+ + "\t"
133
+ + sf.decoder(x.replace("<pad>", "").replace("</s>", ""))
134
+ + "\n"
135
+ )
136
+
137
+ all_outputs += outputs
138
+ all_caption += caption
139
+ all_smiles += smiles
140
+
141
+ next_batch_start = next_batch_end
142
+ next_batch_end = min(next_batch_end + args.batch_size, len(validation_dataset))
143
+ pbar.update(1)
144
+
145
+ if next_batch_start == len(validation_dataset):
146
+ break
147
+
148
+ with open(args.outputdir.replace(".txt", "_final.txt"), "w") as f:
149
+ for i, x in enumerate(all_outputs):
150
+ f.write(
151
+ all_caption[i]
152
+ + "\t"
153
+ + all_smiles[i]
154
+ + "\t"
155
+ + sf.decoder(x.replace("<pad>", "").replace("</s>", ""))
156
+ + "\n"
157
+ )
158
+
159
+
160
+ def create_argparser():
161
+ defaults = dict(
162
+ clip_denoised=False,
163
+ mbr_sample=1,
164
+ model_path="",
165
+ model_arch="conv-unet",
166
+ verbose="yes",
167
+ )
168
+ text_defaults = dict(
169
+ modality="text",
170
+ dataset_name="",
171
+ dataset_config_name="wikitext-2-raw-v1",
172
+ dataset_path="dataset",
173
+ experiment="gpt2_pre_compress",
174
+ model_arch="trans-unet",
175
+ model_in_channels=32,
176
+ model_model_channels=128,
177
+ model_dropout=0.1,
178
+ model_hidden_size=1024,
179
+ model_num_attention_heads=16,
180
+ model_num_hidden_layers=12,
181
+ preprocessing_num_workers=1,
182
+ emb_scale_factor=1.0,
183
+ clamp="clamp",
184
+ split="validation",
185
+ model_path="",
186
+ use_ddim=False,
187
+ batch_size=16,
188
+ top_p=1.0,
189
+ outputdir="output.txt",
190
+ diffusion_steps=2000,
191
+ token_max_length=256,
192
+ start=0,
193
+ )
194
+ defaults.update(model_and_diffusion_defaults())
195
+ defaults.update(text_defaults)
196
+ parser = argparse.ArgumentParser()
197
+ add_dict_to_argparser(parser, defaults)
198
+ return parser
199
+
200
+
201
+ if __name__ == "__main__":
202
+ main()
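Every option above carries an argparse default, so a run normally only needs to override the checkpoint and output paths; an illustrative invocation (the paths are placeholders, not repository defaults):

python inference.py --model_path path/to/checkpoint.pt --split validation --batch_size 16 --outputdir results/validation.txt

Each batch is appended to the `--outputdir` file as tab-separated caption, ground-truth SMILES, and predicted SMILES; once all batches finish, a consolidated copy is rewritten to the matching `*_final.txt`.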
inference_submission.py ADDED
@@ -0,0 +1,189 @@
 
1
+ import torch
2
+ import argparse
3
+ import selfies as sf
4
+ from tqdm import tqdm
5
+ from transformers import set_seed
6
+ from src.scripts.mytokenizers import Tokenizer
7
+ from src.improved_diffusion import gaussian_diffusion as gd
8
+ from src.improved_diffusion import dist_util, logger
9
+ from src.improved_diffusion.respace import SpacedDiffusion
10
+ from src.improved_diffusion.transformer_model import TransformerNetModel
11
+ from src.improved_diffusion.script_util import (
12
+ model_and_diffusion_defaults,
13
+ add_dict_to_argparser,
14
+ )
15
+ from src.scripts.mydatasets import Lang2molDataset_submission
16
+
17
+
18
+ def main():
19
+ set_seed(42)
20
+ args = create_argparser().parse_args()
21
+
22
+ # dist_util.setup_dist()
23
+ logger.configure()
24
+ args.sigma_small = True
25
+
26
+ # args.diffusion_steps = 200 #500 # DEBUG
27
+
28
+ if args.experiment == "random1":
29
+ args.experiment = "random"
30
+ logger.log("creating model and diffusion...")
31
+ tokenizer = Tokenizer()
32
+ model = TransformerNetModel(
33
+ in_channels=args.model_in_channels,
34
+ model_channels=args.model_model_channels,
35
+ dropout=args.model_dropout,
36
+ vocab_size=len(tokenizer),
37
+ hidden_size=args.model_hidden_size,
38
+ num_attention_heads=args.model_num_attention_heads,
39
+ num_hidden_layers=args.model_num_hidden_layers,
40
+ )
41
+ model.eval()
42
+ diffusion = SpacedDiffusion(
43
+ use_timesteps=[i for i in range(0, args.diffusion_steps, 10)],
44
+ betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
45
+ model_mean_type=(gd.ModelMeanType.START_X),
46
+ model_var_type=((gd.ModelVarType.FIXED_LARGE)),
47
+ loss_type=gd.LossType.E2E_MSE,
48
+ rescale_timesteps=True,
49
+ model_arch="transformer",
50
+ training_mode="e2e",
51
+ )
52
+
53
+ model.load_state_dict(
54
+ dist_util.load_state_dict(args.model_path, map_location="cpu")
55
+ )
56
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
57
+ logger.log(f"the parameter count is {pytorch_total_params}")
58
+ model.to(dist_util.dev())
59
+ model.eval()
60
+
61
+ logger.log("sampling...")
62
+ print("--" * 30)
63
+ print(f"Loading {args.split} set")
64
+ print("--" * 30)
65
+
66
+ validation_dataset = Lang2molDataset_submission(
67
+ dir=args.dataset_path,
68
+ tokenizer=tokenizer,
69
+ split=args.split,
70
+ corrupt_prob=0.0,
71
+ token_max_length=args.token_max_length,
72
+ dataset_name=args.dataset_name,
73
+ )
74
+ print("-------------------- DATASET INFO --------------------")
75
+ print(f"Size: {len(validation_dataset)} samples")
76
+ print(f'Sample shape: {validation_dataset[0]["caption_state"].shape}')
77
+
78
+ print(f"Use DDIM: {args.use_ddim}")
79
+ sample_fn = (
80
+ diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
81
+ )
82
+
83
+ print(f"Batch size: {args.batch_size}")
84
+ next_batch_start = args.start
85
+ next_batch_end = min(next_batch_start + args.batch_size, len(validation_dataset))  # clamp so the first batch cannot index past the dataset
86
+ all_outputs = []
87
+ all_caption = []
88
+ pbar = tqdm(
89
+ total=len(validation_dataset) // args.batch_size + 1
90
+ if len(validation_dataset) % args.batch_size != 0
91
+ else len(validation_dataset) // args.batch_size
92
+ )
93
+ while True:
94
+ sample = [
95
+ (
96
+ validation_dataset[i]["caption_state"],
97
+ validation_dataset[i]["caption_mask"],
98
+ validation_dataset[i]["caption"],
99
+ )
100
+ for i in range(next_batch_start, next_batch_end)
101
+ ]
102
+ caption_state = torch.concat([i[0] for i in sample], dim=0)
103
+ caption_mask = torch.concat([i[1] for i in sample], dim=0)
104
+ caption = [i[2] for i in sample]
105
+
106
+ outputs = sample_fn(
107
+ model,
108
+ (caption_state.size(0), 256, model.in_channels),  # match the actual (possibly partial) batch
109
+ clip_denoised=args.clip_denoised,
110
+ denoised_fn=None,
111
+ model_kwargs={},
112
+ top_p=args.top_p,
113
+ progress=True,
114
+ caption=(caption_state, caption_mask),
115
+ )
116
+
117
+ logits = model.get_logits(torch.tensor(outputs).cuda())
118
+ cands = torch.topk(logits, k=1, dim=-1)
119
+ outputs = cands.indices
120
+ outputs = outputs.squeeze(-1)
121
+ outputs = tokenizer.decode(outputs)
122
+
123
+ with open(args.outputdir, "a") as f:
124
+ for i, x in enumerate(outputs):
125
+ f.write(
126
+ sf.decoder(
127
+ x.replace("<pad>", "").replace("</s>", "").replace("\t", "")
128
+ ).replace("\t", "")
129
+ + "\n"
130
+ )
131
+
132
+ all_outputs += outputs
133
+ all_caption += caption
134
+
135
+ next_batch_start = next_batch_end
136
+ next_batch_end = min(next_batch_end + args.batch_size, len(validation_dataset))
137
+ pbar.update(1)
138
+
139
+ if next_batch_start == len(validation_dataset):
140
+ break
141
+
142
+ with open(args.outputdir.replace(".txt", "_final.txt"), "w") as f:
143
+ for i, x in enumerate(all_outputs):
144
+ f.write(sf.decoder(x.replace("<pad>", "").replace("</s>", "")) + "\n")
145
+
146
+
147
+ def create_argparser():
148
+ defaults = dict(
149
+ clip_denoised=False,
150
+ mbr_sample=1,
151
+ model_path="",
152
+ model_arch="conv-unet",
153
+ verbose="yes",
154
+ )
155
+ text_defaults = dict(
156
+ modality="text",
157
+ dataset_name="language-plus-molecules/LPM-24_eval-molgen",
158
+ dataset_config_name="wikitext-2-raw-v1",
159
+ dataset_path="dataset",
160
+ experiment="gpt2_pre_compress",
161
+ model_arch="trans-unet",
162
+ model_in_channels=32,
163
+ model_model_channels=128,
164
+ model_dropout=0.1,
165
+ model_hidden_size=1024,
166
+ model_num_attention_heads=16,
167
+ model_num_hidden_layers=12,
168
+ preprocessing_num_workers=1,
169
+ emb_scale_factor=1.0,
170
+ clamp="clamp",
171
+ split="train",
172
+ model_path="",
173
+ use_ddim=False,
174
+ batch_size=7,
175
+ top_p=1.0,
176
+ outputdir="output.txt",
177
+ diffusion_steps=2000,
178
+ token_max_length=256,
179
+ start=0,
180
+ )
181
+ defaults.update(model_and_diffusion_defaults())
182
+ defaults.update(text_defaults)
183
+ parser = argparse.ArgumentParser()
184
+ add_dict_to_argparser(parser, defaults)
185
+ return parser
186
+
187
+
188
+ if __name__ == "__main__":
189
+ main()
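inference_submission.py is a near copy of inference.py adapted to the unlabeled evaluation split: it loads Lang2molDataset_submission, defaults `dataset_name` to language-plus-molecules/LPM-24_eval-molgen, carries no ground-truth SMILES column, and writes one decoded SMILES string per line, presumably the format expected for a submission file. An illustrative run (the checkpoint path is a placeholder):

python inference_submission.py --model_path path/to/checkpoint.pt --outputdir submission.txt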
requirements.txt ADDED
Binary file (128 Bytes).
src/__init__.py ADDED
File without changes
src/anlg_infill/anlg.py ADDED
@@ -0,0 +1,130 @@
 
1
+ import json
2
+ import sys, os, torch
3
+ from spacy.lang.en import English
4
+ from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer
5
+ from transformers import AutoModelForCausalLM
6
+ # read files.
7
+ # with open('diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json', 'r') as f:
8
+ SPLIT = 'test'
9
+
10
+ if SPLIT == 'val':
11
+ source_file = 'diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
12
+ elif SPLIT == 'test':
13
+ source_file = 'diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
14
+ else:
15
+ assert False, "invalid split"
16
+
17
+ with open(source_file, 'r') as f:
18
+ sent_lst = json.load(f)
19
+
20
+
21
+ nlp = English()
22
+ tokenizer = nlp.tokenizer
23
+ MODE = 'ar'
24
+
25
+ '''
26
+ "00b9adb2-b3b6-4737-902a-50f308bac4b5-1": {
27
+ "gold_labels": [
28
+ "I put my baby in the car and drove around.",
29
+ "I realized he needed his blanket, which I had forgotten at a faraway hotel.",
30
+ "I took a drive to get my baby to sleep.",
31
+ "I took my baby for a drive and she fell asleep in the car."
32
+ ],
33
+ "obs1": "My baby would not go to sleep last night.",
34
+ "obs2": "I wound up driving for hours."
35
+ },
36
+ '''
37
+ print(len(sent_lst))
38
+
39
+ if MODE == 'ar':
40
+ model_name = 'predictability/diff_models/roc_e=20_b=32_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill'
41
+ model_name = 'predictability/diff_models/roc_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill_v2'
42
+ model = AutoModelForCausalLM.from_pretrained(
43
+ model_name, # path to the AR model trained for LMing this task.
44
+ ).cuda()
45
+ tokenizer2 = load_tokenizer('roc', 'random',
46
+ 'predictability/diffusion_models_v7/diff_roc_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart')
47
+ vocab = {v: k for k, v in tokenizer2.items()}
48
+ print(len(tokenizer2), len(vocab), 'loaded vocabs')
49
+
50
+ outfile='ar_sample_full_test_v2.json'
51
+ filehandle = open(outfile, 'w')
52
+
53
+ for idx, (key, val) in enumerate(sent_lst.items()):
54
+ # if idx <= 499:
55
+ # continue
56
+ # if idx >= 500:
57
+ # continue
58
+ # if idx != 684:
59
+ # continue
60
+
61
+ if MODE == 'diff':
62
+ partial_seq = f"{val['obs1']} " + "PAD "*10 + f"{val['obs2']}"
63
+ word_lst = [x.text for x in tokenizer(partial_seq)]
64
+ partial_seq = " ".join(word_lst)
65
+ print(partial_seq, idx)
66
+ # partial_seq = "Brenna and I used to be best friends . PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD We never talked again ."
67
+ COMMAND = "python ../scripts/infill.py " \
68
+ "--model_path predictability/diffusion_models_v7/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long/ema_0.9999_800000.pt " \
69
+ " --batch_size 50 " \
70
+ f"--partial_seq \'{partial_seq}\' " \
71
+ f"--eval_task_ infill --notes {SPLIT}_{idx} " \
72
+ f"--out_dir ../anlg_results"
73
+ os.system(COMMAND)
74
+ torch.cuda.empty_cache()
75
+ elif MODE == 'ar':
76
+ partial_seq = f"{val['obs1']} " + f"{val['obs2']}"
77
+ print(partial_seq)
78
+ word_idx_lst = [vocab['START']] + [vocab.get(x.text, vocab['UNK']) for x in tokenizer(partial_seq)]
79
+ init_prompt = torch.LongTensor(word_idx_lst).cuda().unsqueeze(0)
80
+ print(init_prompt.shape)
81
+ # sample_out = model.generate(init_prompt, do_sample=True, max_length=64, top_k=len(vocab))
82
+ if 'sample' in outfile:
83
+ print('sampling 50 examples.')
84
+ init_prompt = init_prompt.expand(50, -1)
85
+ sample_out = model.generate(init_prompt, do_sample=True, max_length=64, top_k=len(vocab))
86
+ else:
87
+ sample_out = model.generate(init_prompt, do_sample=False, num_beams=4, max_length=64, top_k=len(vocab))
88
+
89
+ print(sample_out.shape)
90
+ sample_out = sample_out[:, init_prompt.size(1):]
91
+ # decode
92
+ if 'sample' in outfile:
93
+ sample_lst = []
94
+ for examp in sample_out:
95
+ sample = examp.tolist()
96
+ words_sample = [tokenizer2[s] for s in sample]
97
+ tempsent = [x for x in words_sample if x != 'PAD']
98
+ if tempsent[0] == 'START':
99
+ tempsent = tempsent[1:]
100
+ if tempsent[-1] == 'END':
101
+ tempsent = tempsent[:-1]
102
+ result_sent = " ".join(tempsent)
103
+ sample_lst.append(result_sent)
104
+ out_dict = {'idx': idx,
105
+ 'obs1': val['obs1'],
106
+ 'obs2': val['obs2'],
107
+ 'samples': sample_lst}
108
+ print(json.dumps(out_dict), file=filehandle)
109
+ else:
110
+ sample = sample_out[0].tolist()
111
+ words_sample = [tokenizer2[s] for s in sample]
112
+ tempsent = [x for x in words_sample if x != 'PAD']
113
+ if tempsent[0] == 'START':
114
+ tempsent = tempsent[1:]
115
+ if tempsent[-1] == 'END':
116
+ tempsent = tempsent[:-1]
117
+ result_sent = " ".join(tempsent)
118
+ out_dict = {'idx':idx,
119
+ 'obs1':val['obs1'],
120
+ 'obs2':val['obs2'],
121
+ 'sample':result_sent}
122
+ print(json.dumps(out_dict), file=filehandle)
123
+ filehandle.close()
124
+ print(f'written to {outfile}')
125
+
126
+
127
+
128
+
129
+
130
+
src/anlg_infill/mbr_eval.py ADDED
@@ -0,0 +1,351 @@
 
1
+ import os, sys, json
2
+ import glob
3
+ from functools import partial
4
+ sys.path.insert(0, 'e2e-metrics')
5
+ import numpy as np
6
+ from pycocotools.coco import COCO
7
+ from pycocoevalcap.eval import COCOEvalCap
8
+ from metrics.pymteval import BLEUScore, NISTScore
9
+ from nltk.translate.meteor_score import meteor_score
10
+ from parse import *
11
+ import json
12
+ import sys, os, torch
13
+ from spacy.lang.en import English
14
+ import ast
15
+ from transformers import BertForMaskedLM, BertTokenizer
16
+
17
+ MODE = sys.argv[1] # ar or diff
18
+ SPLIT = sys.argv[2] # val or test
19
+ OUT_PATH = sys.argv[3] # output path.
20
+ INPUT_PATH = sys.argv[4] # input path. e.g. diffusion_lm/improved-diffusion/anlg_results/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long.ema_0.9999_800000.pt.infill_infill
21
+
22
+ def load_results_simple(path):
23
+ with open(path, 'r') as f:
24
+ full_result_dict = json.load(f)
25
+ return full_result_dict
26
+
27
+ def post_process(filename, fileout, tokenizer_spacy):
28
+ # filename = 'diffusion_lm/improved-diffusion/anlg_results/diff_roc_mbr.json2'
29
+ bert_model = 'bert-base-cased'
30
+ tokenizer = BertTokenizer.from_pretrained(bert_model)
31
+ model = BertForMaskedLM.from_pretrained(bert_model).cuda()
32
+ fileout_handle = open(fileout, 'w')
33
+
34
+ full_lst = []
35
+ with open(filename, 'r') as f:
36
+ for line in f:
37
+ line = json.loads(line)
38
+ full_lst.append(line)
39
+
40
+ for example in full_lst:
41
+ sent = example['sample']
42
+ obs1 = example['obs1']
43
+ obs2 = example['obs2']
44
+ if 'UNK' in sent:
45
+ sent = obs1 + sent.replace('UNK', tokenizer.mask_token) + obs2
46
+ print(sent)
47
+ model_inputs = tokenizer(sent, return_tensors="pt")
48
+ model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
49
+ model_out = model(**model_inputs)
50
+ mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
51
+ masked_logits = model_out.logits[mask_words].view(-1, model_out.logits.size(-1))
52
+ # take argmax from this.
53
+ max_cands = torch.max(masked_logits, dim=-1)
54
+ indices = max_cands.indices
55
+ model_inputs['input_ids'][mask_words] = indices
56
+ out = tokenizer.batch_decode(model_inputs['input_ids'].tolist(),
57
+ skip_special_tokens=True)[0]
58
+ print(out)
59
+ word_lstout = [x.text for x in tokenizer_spacy(out)]
60
+ word_lst1 = [x.text for x in tokenizer_spacy(example['obs1'])]
61
+ word_lst2 = [x.text for x in tokenizer_spacy(example['obs2'])]
62
+ example['sample'] = " ".join(word_lstout[len(word_lst1):-len(word_lst2)])
63
+ print(example['sample'])
64
+ print()
65
+
66
+
67
+ else:
68
+ print('no UNK token; nothing to fix')
69
+
70
+
71
+ print(json.dumps(example), file=fileout_handle)
72
+
73
+ fileout_handle.close()
74
+
75
+
76
+
77
+ def load_results(sent_lst, tokenizer):
78
+ # target_file = f"{INPUT_PATH}_*.json"
79
+ # target_file = glob.glob(target_file)
80
+ # print([x for x in target_file if 'val' not in x and 'test' not in x])
81
+ # 10/0
82
+ full_result_dict = {}
83
+ failed_instances = []
84
+ found_idx = []
85
+ sent_lst_lst = list(sent_lst.items())
86
+ for idx, (key, val) in enumerate(sent_lst_lst):
87
+ # if idx < 2500: continue
88
+ if idx in full_result_dict.keys(): continue
89
+ word_lst1 = [x.text for x in tokenizer(val['obs1'])]
90
+ word_lst2 = [x.text for x in tokenizer(val['obs2'])]
91
+ # target_file = f"diffusion_lm/improved-diffusion/anlg_results/diff_roc_pad_rand128_" \
92
+ # f"transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long.ema" \
93
+ # f"_0.9999_800000.pt.infill_infill_*_{SPLIT}_{idx}.json"
94
+ target_file = f"{INPUT_PATH}_*_{SPLIT}_{idx}.json"
95
+
96
+ file_lst = glob.glob(target_file)
97
+ # print(file_lst, target_file)
98
+ try:
99
+ assert len(file_lst) == 1
100
+ except:
101
+ print('file not found under the split-specific name; falling back to batched naming')
102
+ # if SPLIT == 'val': assert False
103
+ # if idx % 100 == 1: idx = idx-1
104
+ target_file = f"{INPUT_PATH}_*_{idx}.json"
105
+ file_lst = glob.glob(target_file)
106
+ print(file_lst, target_file)
107
+ print(file_lst)
108
+ target_file = file_lst[0]
109
+ if "x128" in target_file:
110
+ infill_lst = []
111
+ with open(target_file, 'r') as f:
112
+ for line in f:
113
+ example = json.loads(line)[0]
114
+ infill_ = example.split()[len(word_lst1):-len(word_lst2)]
115
+ # print(len(infill_))
116
+ # print(infill_, example)
117
+ # assert len(infill_) == 10
118
+ infill_=' '.join(infill_)
119
+ # print(infill_)
120
+ infill_lst.append(infill_)
121
+ result_dict = {
122
+ "pred_samples": infill_lst,
123
+ "sample": None,
124
+ "obs1": val['obs1'],
125
+ "obs2": val['obs2']
126
+ }
127
+ full_result_dict[idx] = result_dict
128
+ else:
129
+ with open(target_file, 'r') as f:
130
+ for line in f:
131
+ example = ast.literal_eval(line.strip())
132
+ index, template = list(example.keys())[0]
133
+ print(index, idx)
134
+ if int(index) < int(idx):
135
+ continue
136
+ assert int(index) == int(idx)
137
+ found_idx.append(idx)
138
+ example = list(example.values())[0]
139
+ kk, val = sent_lst_lst[idx]
140
+ word_lst1 = [x.text for x in tokenizer(val['obs1'])]
141
+ word_lst2 = [x.text for x in tokenizer(val['obs2'])]
142
+ infill_lst = [" ".join(xx.split()[len(word_lst1):-len(word_lst2)]) for xx in example]
143
+ result_dict = {
144
+ "pred_samples": infill_lst,
145
+ "sample": None,
146
+ "obs1": val['obs1'],
147
+ "obs2": val['obs2']
148
+ }
149
+ full_result_dict[idx] = result_dict
150
+ idx += 1
151
+
152
+ with open('full_diff_test_outputs_aug.json', 'w') as f:
153
+ json.dump(full_result_dict, f)
154
+ return full_result_dict
155
+
156
+
157
+ # read files.
158
+ def mbr(result_lst, total_len, sample_size, utility):
159
+ result = []
160
+ for i in range(total_len):
161
+ example_set = result_lst[i * sample_size:(i + 1) * sample_size]
162
+ # print(example_set)
163
+ score_dict = {}
164
+ for idx in range(len(example_set)):
165
+ y = example_set[idx]
166
+ utility_lst = []
167
+ for idx_x in range(len(example_set)):
168
+ if idx_x != idx:
169
+ utility_lst.append(utility(example_set[idx_x], y))
170
+ score_dict[idx] = np.array(utility_lst).mean()
171
+ # print(score_dict)
172
+ best_y = sorted(score_dict.items(), key=lambda item: item[1])[-1]
173
+ result.append(example_set[best_y[0]])
174
+ # print(best_y)
175
+
176
+ return result
177
+
178
+
179
+ def bleu_score(scorer, sent_sys, sents_ref):
180
+ scorer.reset()
181
+ scorer.append(sent_sys, [sents_ref])
182
+ return scorer.score()
183
+
184
+
185
+ def meteor_score2(pred, ref):
186
+ meteor = meteor_score([ref.split()], pred.split())
187
+ return meteor
188
+
189
+ def apply_mbr_func(full_result_dict, outpath, sent_lst):
190
+ assert len(sent_lst) == len(full_result_dict)
191
+ out_handle = open(outpath, 'w')
192
+ count = 0
193
+ for idx, val in full_result_dict.items():
194
+ infill_lst = val['pred_samples']
195
+ print(count, idx )
196
+ assert count == int(idx)
197
+ count += 1
198
+ sample_size = len(infill_lst)
199
+ total_len = 1
200
+ mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
201
+ result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
202
+ print(infill_lst)
203
+ print(result_lst)
204
+ result_str = result_lst[0]
205
+ result_dict = {
206
+ "pred_samples": infill_lst,
207
+ "sample": result_str,
208
+ "obs1": val['obs1'],
209
+ "obs2": val['obs2']
210
+ }
211
+ print(json.dumps(result_dict), file=out_handle)
212
+ out_handle.close()
213
+ print(f'written to {outpath}')
214
+ return
215
+
216
+ if SPLIT == 'val':
217
+ source_file = 'diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
218
+ elif SPLIT == 'test':
219
+ source_file = 'diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
220
+ else:
221
+ assert False, "invalid split"
222
+
223
+ with open(source_file, 'r') as f:
224
+ sent_lst = json.load(f)
225
+
226
+
227
+
228
+ if MODE == 'diff':
229
+ nlp = English()
230
+ tokenizer = nlp.tokenizer
231
+ # load_results(sent_lst, tokenizer)
232
+ # 10/0
233
+ decoded_dict = load_results_simple(INPUT_PATH)
234
+ ############3
235
+ # small_decoded_dict = {}
236
+ # for i in range(10):
237
+ # small_decoded_dict[i] = decoded_dict[str(i)]
238
+ # decoded_dict = small_decoded_dict
239
+ # small_sent_lst = {}
240
+ # for k, v in sent_lst.items():
241
+ # if len(small_sent_lst) > 9: break
242
+ # small_sent_lst[k] = v
243
+ # sent_lst = small_sent_lst
244
+ ############3
245
+ outpath = OUT_PATH
246
+ apply_mbr_func(decoded_dict, outpath, sent_lst)
247
+ post_process(outpath, outpath+'.clean.json', tokenizer)
248
+
249
+ #
250
+ # # load_results(sent_lst, tokenizer)
251
+ # # 10/0
252
+ # print(len(sent_lst))
253
+ # for idx, (key, val) in enumerate(sent_lst.items()):
254
+ # # if idx < 518: continue
255
+ # if idx > 900:
256
+ # break
257
+ # # change the matching method.
258
+ # word_lst1 = [x.text for x in tokenizer(val['obs1'])]
259
+ # word_lst2 = [x.text for x in tokenizer(val['obs2'])]
260
+ # # partial_seq = f"{val['obs1']} " + "PAD " + f"{val['obs2']}"
261
+ # # word_lst = [x.text for x in tokenizer(partial_seq)]
262
+ # # partial_seq = " ".join(word_lst)
263
+ # # partial_seq = partial_seq.replace('PAD', '{}')
264
+ # # print(partial_seq, idx)
265
+ #
266
+ # # target_file = f"diffusion_lm/improved-diffusion/anlg_results/diff_roc_pad_rand128_" \
267
+ # # f"transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long.ema" \
268
+ # # f"_0.9999_800000.pt.infill_infill_*_{SPLIT}_{idx}.json"
269
+ # target_file = f"{INPUT_PATH}_*_{SPLIT}_{idx}.json"
270
+ #
271
+ # file_lst = glob.glob(target_file)
272
+ # print(file_lst, target_file)
273
+ # assert len(file_lst) == 1
274
+ # target_file = file_lst[0]
275
+ # # print(target_file)
276
+ # infill_lst = []
277
+ # with open(target_file, 'r') as f:
278
+ # for line in f:
279
+ # example = json.loads(line)[0]
280
+ # # print(example, partial_seq)
281
+ # # infill_ = parse(partial_seq, example)
282
+ # # print(example)
283
+ # infill_ = example.split()[len(word_lst1):-len(word_lst2)]
284
+ # # print(len(infill_))
285
+ # # print(infill_, example)
286
+ # # assert len(infill_) == 10
287
+ # infill_=' '.join(infill_)
288
+ # # print(infill_)
289
+ # infill_lst.append(infill_)
290
+ # infill_lst = infill_lst
291
+ # sample_size = len(infill_lst)
292
+ # total_len = 1
293
+ # mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
294
+ # result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
295
+ # print(infill_lst)
296
+ # print(result_lst)
297
+ # result_str = result_lst[0]
298
+ # result_dict = {
299
+ # "pred_samples": infill_lst,
300
+ # "sample":result_str,
301
+ # "obs1": val['obs1'],
302
+ # "obs2": val['obs2']
303
+ # }
304
+ # print(json.dumps(result_dict), file=out_handle)
305
+ #
306
+ # out_handle.close()
307
+ # print(f'written to {outpath}')
308
+
309
+ elif MODE == 'ar':
310
+ outpath = OUT_PATH #'diffusion_lm/improved-diffusion/anlg_results/ar_full_mbr.json'
311
+ out_handle = open(outpath, 'w')
312
+ sample_file = INPUT_PATH #'diffusion_lm/improved-diffusion/anlg_results/ar_sample_500_v2.json'
313
+ nlp = English()
314
+ tokenizer = nlp.tokenizer
315
+ print(len(sent_lst))
316
+ sample_lst = []
317
+ with open(sample_file, 'r') as f:
318
+ for line in f:
319
+ sample_dict = json.loads(line)
320
+ sample_lst.append(sample_dict)
321
+
322
+ for idx, (key, val) in enumerate(sent_lst.items()):
323
+ # if idx < 109: continue
324
+ # if idx > 499:
325
+ # break
326
+ infill_lst = sample_lst[idx]['samples']
327
+ sample_size = len(infill_lst)
328
+ total_len = 1
329
+ mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
330
+ result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
331
+ print(infill_lst)
332
+ print(result_lst)
333
+ result_str = result_lst[0]
334
+ result_dict = {
335
+ "pred_samples": infill_lst,
336
+ "sample": result_str,
337
+ "obs1": val['obs1'],
338
+ "obs2": val['obs2']
339
+ }
340
+ print(json.dumps(result_dict), file=out_handle)
341
+
342
+ out_handle.close()
343
+ print(f'written to {outpath}')
344
+
345
+ post_process(outpath, outpath + '.clean.json', tokenizer)
346
+
347
+ # print(file+'.clean')
348
+ # with open(file+'.clean', 'w') as f:
349
+ # for line in result_lst:
350
+ # print(line, file=f)
351
+
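The `mbr` helper above is a generic minimum-Bayes-risk selector: each candidate is scored by its mean utility against every other candidate (smoothed BLEU via `bleu_score` in this script), and the candidate with the highest mean, i.e. the most central one, is kept. A toy sketch of the same selection, assuming `mbr` is in scope (copied from this file) and substituting a simple word-overlap utility for BLEU:

def overlap(ref, hyp):
    # symmetric Jaccard overlap over whitespace tokens; a stand-in for BLEU
    a, b = set(ref.split()), set(hyp.split())
    return len(a & b) / max(len(a | b), 1)

candidates = ["the cat sat", "the cat sat down", "the cat"]
best = mbr(candidates, total_len=1, sample_size=len(candidates), utility=overlap)
print(best)  # -> ['the cat sat'], the candidate closest on average to the others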
src/anlg_infill/post_process.py ADDED
@@ -0,0 +1,35 @@
 
1
+ import torch
2
+ import json
3
+ from transformers import BertForMaskedLM, BertTokenizer
4
+ filename = 'diffusion_lm/improved-diffusion/anlg_results/diff_roc_mbr.json2'
5
+ bert_model = 'bert-base-uncased'
6
+ tokenizer = BertTokenizer.from_pretrained(bert_model)
7
+ model = BertForMaskedLM.from_pretrained(bert_model).cuda()
8
+
9
+ full_lst = []
10
+ with open(filename, 'r') as f:
11
+ for line in f:
12
+ line = json.loads(line)
13
+ full_lst.append(line)
14
+
15
+ for example in full_lst:
16
+ sent = example['sample']
17
+ obs1 = example['obs1']
18
+ obs2 = example['obs2']
19
+ if 'UNK' in sent:
20
+ sent = obs1 + sent.replace('UNK', tokenizer.mask_token) + obs2
21
+ print(sent)
22
+ model_inputs = tokenizer(sent,return_tensors="pt")
23
+ model_inputs = {k:v.to(model.device) for k,v in model_inputs.items()}
24
+ model_out = model(**model_inputs)
25
+ mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
26
+ masked_logits = model_out.logits[mask_words].view(-1, model_out.logits.size(-1))
27
+ if masked_logits.size(0) > 0:
28
+ # take argmax from this.
29
+ max_cands = torch.max(masked_logits, dim=-1)
30
+ indices = max_cands.indices
31
+ model_inputs['input_ids'][mask_words] = indices
32
+ print(tokenizer.batch_decode(model_inputs['input_ids'].tolist()))
33
+ else:
34
+ print('no UNK token; nothing to fix')
35
+
src/anlg_infill/run_evaluation.py ADDED
@@ -0,0 +1,81 @@
 
1
+ import torch, json, sys
2
+
3
+ SPLIT = sys.argv[1] # val or test
4
+ MBR_PATH = sys.argv[2] # output path.
5
+
6
+ # read files.
7
+ if SPLIT == 'val':
8
+ source_file = '/diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
9
+ elif SPLIT == 'test':
10
+ source_file = '/diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
11
+ else:
12
+ assert False, "invalid split"
13
+
14
+ with open(source_file, 'r') as f:
15
+ sent_lst = json.load(f)
16
+
17
+ # read generation
18
+ generated_lst = []
19
+ # with open('/diffusion_lm/improved-diffusion/anlg_results/ar_beam_500.json', 'r') as f:
20
+ # with open('/diffusion_lm/improved-diffusion/anlg_results/ar_beam_500_v2.json', 'r') as f:
21
+ # with open('/diffusion_lm/improved-diffusion/anlg_results/ar_full_mbr.json', 'r') as f:
22
+ # with open('/diffusion_lm/improved-diffusion/anlg_results/diff_full.json', 'r') as f:
23
+ with open(MBR_PATH, 'r') as f:
24
+ for line in f:
25
+ generated_lst.append(json.loads(line))
26
+
27
+ print(len(generated_lst), len(sent_lst))
28
+ # eval_file_gen = "/diffusion_lm/improved-diffusion/anlg_results/ar_gen_mbr_v2.txt"
29
+ # eval_file_gold = "/diffusion_lm/improved-diffusion/anlg_results/ar_ref_mbr_v2.txt"
30
+ if SPLIT == 'val':
31
+ eval_file_gen = f"{MBR_PATH}_gen.txt"
32
+ fgen = open(eval_file_gen, 'w')
33
+ eval_file_gold = f"{MBR_PATH}_ref.txt" # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
34
+ fgold = open(eval_file_gold, 'w')
35
+ for gen, gold in zip(generated_lst, sent_lst.items()):
36
+ print(gen['sample'], file=fgen)
37
+ gold = gold[1]
38
+ for x in gold['gold_labels']:
39
+ print(x, file=fgold)
40
+ print('', file=fgold)
41
+ fgold.close()
42
+ fgen.close()
43
+ elif SPLIT == 'test':
44
+ eval_file_prediction = f"{MBR_PATH}_prediction.json" # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
45
+ # fpred = open(eval_file_prediction, 'w')
46
+ full_dict = {}
47
+ for gen, gold in zip(generated_lst, sent_lst.items()):
48
+ print(gold)
49
+ print(gen['sample'])
50
+ full_dict[gold[0]] = gen['sample']
51
+ # temp_dict = {gold[0]:gen['sample']}
52
+ # print(temp_dict)
53
+ # print(json.dumps(temp_dict), file=fpred)
54
+ # gold = gold[1]
55
+ # for x in gold['gold_labels']:
56
+ # print(x, file=fgold)
57
+ # print('', file=fgold)
58
+ with open(eval_file_prediction, 'w') as fpred:
59
+ json.dump(full_dict, fpred)
60
+
61
+ ###########
62
+ test_ref = '/diffusion_lm/ROCstory/anlg/anlg/test_cleanup_ref.json'
63
+ with open(test_ref, 'r') as f:
64
+ test_ref_lst = json.load(f)
65
+
66
+ eval_file_gen = f"{MBR_PATH}_gen.txt"
67
+ fgen = open(eval_file_gen, 'w')
68
+ eval_file_gold = f"{MBR_PATH}_ref.txt" # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
69
+ fgold = open(eval_file_gold, 'w')
70
+ for gen, gold in zip(generated_lst, sent_lst.items()):
71
+ story_id = gold[0]
72
+ print(gen['sample'], file=fgen)
73
+ for x in test_ref_lst[story_id]:
74
+ print(x, file=fgold)
75
+ print('', file=fgold)
76
+ fgold.close()
77
+ fgen.close()
78
+
79
+
80
+ # generate prediction.json
81
+
src/control_gen/baseline_control.py ADDED
@@ -0,0 +1,500 @@
 
1
+ # syntax, semantics, etc...
2
+ import torch, json
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ import argparse
6
+ import os
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+ from transformers import set_seed
12
+ from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer
13
+ from improved_diffusion import dist_util, logger
14
+ from improved_diffusion.script_util import (
15
+ NUM_CLASSES,
16
+ model_and_diffusion_defaults,
17
+ create_model_and_diffusion,
18
+ add_dict_to_argparser,
19
+ args_to_dict,
20
+ )
21
+ from nltk.tree import Tree
22
+
23
+ from improved_diffusion.test_util import load_results
24
+
25
+
26
+
27
+ def remove_leaves(tree_):
28
+ # simple_increm = 0
29
+ for s in tree_.subtrees(lambda t: t.height() == 2):
30
+ s[0] = '*'
31
+ s._label = ''
32
+ return tree_
33
+
34
+ def main():
35
+ args = create_argparser().parse_args()
36
+ set_seed(42)
37
+
38
+ # toy1 = 'START Alimentum is not a family - friendly place , located in city centre . \n END'.split()
39
+ # toy1 = 'START Located in riverside area , Alimentum restaurant is a place to bring the whole family . \n END'.split()
40
+ toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating',
41
+ '.', 'Prices', 'start', 'at', '£', '30', '.', 'END']
42
+
43
+ if args.mode == 'tree':
44
+
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ args.model_name_or_path, # path to the AR model trained for LMing this task.
47
+ ).cuda()
48
+ model.eval()
49
+
50
+ if args.finetune == 'yes':
51
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
52
+ else:
53
+
54
+ pass
55
+
56
+ control_label_lst = []
57
+ with open('diffusion_lm/improved-diffusion/control_gen/target_tree.json', 'r') as controlf:
58
+ for line in controlf:
59
+ control_label_lst.append(json.loads(line))
60
+
61
+ result_dict = {}
62
+ for label_class_dict in control_label_lst: # control_label_lst[:100]:
63
+ '''
64
+ input_strings = [" ".join(pos_) + tokenizer.bos_token + " ".join(seq) + tokenizer.eos_token
65
+ for (pos_, seq) in zip(pos_lst, examples['text'])]
66
+ '''
67
+ parse_tree = Tree.fromstring(label_class_dict['tree'])
68
+ print(parse_tree)
69
+ parse_tree = remove_leaves(parse_tree)
70
+
71
+ prompt_strings = parse_tree._pformat_flat("", "()", False) + tokenizer.bos_token
72
+ prompt_ids = tokenizer([prompt_strings], return_tensors='pt')
73
+ out_text = generate_samples(args, prompt_ids['input_ids'].cuda(), model, tokenizer)
74
+ result_dict[(label_class_dict['tree'],)] = out_text
75
+ print(len(out_text))
76
+
77
+ fout = open(args.output_text, 'w')
78
+ for k, word_lst in result_dict.items():
79
+ print({k: word_lst}, file=fout)
80
+ fout.close()
81
+
82
+ # # load trees.
83
+ # import benepar
84
+ # parser = benepar.Parser("benepar_en3")
85
+ # input_sentence1 = benepar.InputSentence(
86
+ # words=toy1[1:-1],
87
+ # )
88
+ # parse_lst = list(parser.parse_sents([input_sentence1]))[0]
89
+ # print(parse_lst)
90
+ # parse_lst = remove_leaves(parse_lst)
91
+ # prompt_strings = parse_lst._pformat_flat("", "()", False) + tokenizer.bos_token
92
+ # print(prompt_strings)
93
+ # prompt_ids = tokenizer([prompt_strings], return_tensors='pt')
94
+ # print(prompt_ids['input_ids'].shape)
95
+ #
96
+ # generate_gpt2(args, prompt_ids['input_ids'].cuda())
97
+
98
+ # eval(args)
99
+ if args.mode == 'spans':
100
+
101
+ model = AutoModelForCausalLM.from_pretrained(
102
+ args.model_name_or_path, # path to the AR model trained for LMing this task.
103
+ ).cuda()
104
+ model.eval()
105
+
106
+ if args.finetune == 'yes':
107
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
108
+ else:
109
+ import benepar
110
+ parser = benepar.Parser("benepar_en3")
111
+ tree_vocab = parser._parser.config["label_vocab"]
112
+
113
+ model_path = 'predictability/diffusion_models_v6/diff_e2e-tgt_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart'
114
+ tokenizer2 = load_tokenizer('e2e-tgt', 'random', model_path)
115
+ tokenizer = {v: k for k, v in tokenizer2.items()}
116
+ print(len(tokenizer), len(tokenizer2), 'loaded vocabs')
117
+
118
+ print('update the vocab to include tree vocabs')
119
+ print(len(tokenizer))
120
+ for x in tree_vocab.keys():
121
+ tokenizer[x] = len(tokenizer)
122
+ print('update the vocab to include indices')
123
+ # tokenizer.add_tokens([str(xx) for xx in range(64)])
124
+ for x in range(64):
125
+ if str(x) not in tokenizer:
126
+ tokenizer[str(x)] = len(tokenizer)
127
+ vocab_dict = tokenizer
128
+ rev_tokenizer = {v: k for k, v in vocab_dict.items()}
129
+ print(len(tokenizer))
130
+
131
+
132
+ control_label_lst = []
133
+ with open('diffusion_lm/improved-diffusion/control_gen/target_spans.json', 'r') as controlf:
134
+ for line in controlf:
135
+ control_label_lst.append(json.loads(line))
136
+
137
+ result_dict = {}
138
+ for span_info in control_label_lst: # control_label_lst[:100]:
139
+ (a,b,c) = span_info['spans'][0]
140
+ if args.finetune == 'yes':
141
+ prompt_strings = f"{a}, {b}, {c}" + tokenizer.bos_token
142
+ print(prompt_strings)
143
+ prompt_ids = tokenizer([prompt_strings], return_tensors='pt')
144
+ out_text = generate_samples(args, prompt_ids['input_ids'].cuda(), model, tokenizer)
145
+ else:
146
+ prompt_ids = [vocab_dict.get(x, vocab_dict['UNK']) for x in f"{a} {b} {c}".split()] + [0]
147
+ print(prompt_ids)
148
+ prompt_ids = torch.LongTensor(prompt_ids).unsqueeze(0)
149
+ out_text = generate_samples_from_scratch(args, prompt_ids.cuda(), model, tokenizer, rev_tokenizer)
150
+ # str(label_class_dict['spans'][0]),
151
+ result_dict[str(span_info['spans'][0])] = out_text
152
+ print(len(out_text))
153
+
154
+ fout = open(args.output_text, 'w')
155
+ for k, word_lst in result_dict.items():
156
+ print({(k,): word_lst}, file=fout)
157
+ fout.close()
158
+ elif args.mode == 'pos':
159
+ import spacy_stanza
160
+ model = AutoModelForCausalLM.from_pretrained(
161
+ args.model_name_or_path, # path to the AR model trained for LMing this task.
162
+ ).cuda()
163
+ model.eval()
164
+
165
+ if args.finetune == 'yes':
166
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
167
+ else:
168
+ pass
169
+
170
+ control_label_lst = []
171
+ with open('diffusion_lm/improved-diffusion/control_gen/target_pos.json', 'r') as controlf:
172
+ for line in controlf:
173
+ control_label_lst.append(json.loads(line))
174
+ print(control_label_lst[:5])
175
+
176
+ result_dict = {}
177
+ for label_class_dict in control_label_lst: # control_label_lst[:100]:
178
+ '''
179
+ input_strings = [" ".join(pos_) + tokenizer.bos_token + " ".join(seq) + tokenizer.eos_token
180
+ for (pos_, seq) in zip(pos_lst, examples['text'])]
181
+ '''
182
+ gold_pos = label_class_dict['pos'][1:-1] # remove START, END.
183
+ words_ = label_class_dict['words_']
184
+ print(gold_pos, 'target POS tagging sequences', tokenizer.bos_token)
185
+ prompt_strings = " ".join(gold_pos) + tokenizer.bos_token
186
+ prompt_ids = tokenizer([prompt_strings], return_tensors='pt')
187
+ out_text = generate_samples(args, prompt_ids['input_ids'].cuda(), model, tokenizer )
188
+ result_dict[tuple(gold_pos)] = out_text
189
+ print(len(out_text))
190
+
191
+ fout = open(args.output_text, 'w')
192
+ for k, word_lst in result_dict.items():
193
+ print({k:word_lst}, file=fout)
194
+ fout.close()
195
+
196
+
197
+ # tagger = spacy_stanza.load_pipeline("en", processors={"tokenize": "spacy"})
198
+ # toy1 = 'START The Mill is a coffee shop with an expensive menu near The Sorrento . \n END'.split()
199
+ # toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating', '.',
200
+ # 'Prices', 'start', 'at', '£', '30', '.', '\n', 'END']
201
+ # sent_full = " ".join(toy1[1:-1])
202
+ # doc = tagger(sent_full)
203
+ # gold_pos = [token.pos_ for token in doc]
204
+ # print(gold_pos, 'target POS tagging sequences')
205
+ # prompt_strings = " ".join(gold_pos) + tokenizer.bos_token
206
+ # prompt_ids = tokenizer([prompt_strings], return_tensors='pt')
207
+ # generate_gpt2(args, prompt_ids['input_ids'].cuda())
208
+
209
+ elif args.mode == 'attribute':
210
+ model = AutoModelForCausalLM.from_pretrained(
211
+ args.model_name_or_path, # path to the AR model trained for LMing this task.
212
+ ).cuda()
213
+ model.eval()
214
+
215
+ if args.finetune == 'yes':
216
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
217
+ else:
218
+ pass
219
+
220
+ control_label_lst = []
221
+ with open('diffusion_lm/improved-diffusion/control_gen/target_attribute.json', 'r') as controlf:
222
+ for line in controlf:
223
+ control_label_lst.append(json.loads(line))
224
+ print(control_label_lst[:5])
225
+
226
+ result_dict = {}
227
+ for label_class in control_label_lst: # control_label_lst[:100]:
228
+ prompt_strings = " ".join(label_class) + tokenizer.bos_token
229
+ '''
230
+ input_strings = [
231
+ " ".join(attributes) + tokenizer.bos_token + " ".join(words) + tokenizer.eos_token
232
+ for (words, attributes) in examples['text']]
233
+ '''
234
+ print(label_class, 'target attribute sequences', tokenizer.bos_token)
235
+ prompt_ids = tokenizer([prompt_strings], return_tensors='pt')
236
+ out_text = generate_samples(args, prompt_ids['input_ids'].cuda(), model, tokenizer)
237
+ result_dict[tuple(label_class)] = out_text
238
+ print(len(out_text))
239
+
240
+ fout = open(args.output_text, 'w')
241
+ for k, word_lst in result_dict.items():
242
+ print({k: word_lst}, file=fout)
243
+ fout.close()
244
+
245
+ elif args.mode == 'control_len':
246
+ model = AutoModelForCausalLM.from_pretrained(
247
+ args.model_name_or_path, # path to the AR model trained for LMing this task.
248
+ ).cuda()
249
+ model.eval()
250
+
251
+ if args.finetune == 'yes':
252
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
253
+ else:
254
+ pass
255
+
256
+
257
+ result_dict = {}
258
+ for label_class in range(10, 41): # control_label_lst[:100]:
259
+ tgt_len = label_class-2
260
+ prompt_strings = f"{tgt_len}" + tokenizer.bos_token
261
+ print(label_class, 'target attribute sequences', tokenizer.bos_token)
262
+ prompt_ids = tokenizer([prompt_strings], return_tensors='pt')
263
+ out_text = generate_samples(args, prompt_ids['input_ids'].cuda(), model, tokenizer)
264
+ result_dict[tuple([label_class])] = out_text
265
+ print(len(out_text))
266
+
267
+ fout = open(args.output_text, 'w')
268
+ for k, word_lst in result_dict.items():
269
+ print({k: word_lst}, file=fout)
270
+ fout.close()
271
+
272
+ # generate_gpt2(args)
273
+
274
+
275
+ def eval(args):
276
+ text_samples = []
277
+ if args.input_text.endswith('json'):
278
+ with open(args.input_text, 'r') as f:
279
+ for line in f:
280
+ text_samples.append(json.loads(line)[0].split(' '))
281
+ else:
282
+ with open(args.input_text, 'r') as f:
283
+ for line in f:
284
+ text_samples.append(line.strip().split())
285
+
286
+ # tokenize
287
+ # load tokenizer.
288
+ tokenizer = load_tokenizer(args.modality, args.experiment, os.path.split(args.model_path)[0])
289
+ # print(args.modality, tokenizer, args.experiment)
290
+ reverse_tokenizer = {v: k for k, v in tokenizer.items()}
291
+
292
+ agg_loss = []
293
    for x in text_samples:
        # print(x)
        tokenized_x = [reverse_tokenizer[s] for s in x]
        # print(tokenized_x)
        tokenized_x = torch.LongTensor(tokenized_x).cuda()
        labels = tokenized_x.clone()
        labels[labels == reverse_tokenizer['PAD']] = -100
        model_output = model(tokenized_x, labels=labels)
        # print(model_output.loss)
        agg_loss.append(model_output.loss.item())

    print(f'\nthe mean loss is {torch.tensor(agg_loss).mean()} for {args.input_text}')
    print('-' * 50)
    if 'infill' in args.input_text:
        json_path = os.path.join(os.path.split(args.model_path)[0], 'infill_score_decode.json')
    elif 'ema' in args.model_path:
        json_path = os.path.join(os.path.split(args.model_path)[0], 'ema_score_decode.json')
    else:
        json_path = os.path.join(os.path.split(args.model_path)[0], 'score_decode.json')
    print(f'written to {json_path}')
    json_dict = {
        'score_decode': torch.tensor(agg_loss).mean().item(),
        'source_decode': args.input_text,
    }
    load_results(json_path, json_dict)


def generate_samples(args, prompt, model, tokenizer):
    if args.generation_mode == 'search':
        sample_out = model.generate(prompt, do_sample=False, max_length=200, min_length=prompt.size(1) + 1,
                                    num_beams=4, top_k=len(tokenizer), top_p=args.top_p, num_return_sequences=1,
                                    pad_token_id=tokenizer.pad_token_id)
    else:
        sample_out = model.generate(prompt, do_sample=True, max_length=200, min_length=prompt.size(1) + 1,
                                    top_k=len(tokenizer), top_p=args.top_p, num_return_sequences=1,
                                    pad_token_id=tokenizer.pad_token_id)
    sample_out_lst = sample_out[:, prompt.size(1):]
    # sample_out_lst.append(sample_out.cpu())
    # sample_out_lst = torch.cat(sample_out_lst, dim=0)
    text_out = []
    for sample in sample_out_lst:
        sample = sample.tolist()
        words_sample = tokenizer.decode(sample, skip_special_tokens=True)
        text_out.append(words_sample)
    return text_out


def generate_samples_from_scratch(args, prompt, model, tokenizer, rev_tokenizer):
    print('generating from scratch')
    if args.generation_mode == 'search':
        sample_out = model.generate(prompt, do_sample=False, max_length=200, min_length=prompt.size(1) + 1,
                                    num_beams=4, top_k=len(tokenizer), top_p=args.top_p, num_return_sequences=1,
                                    pad_token_id=tokenizer['PAD'], eos_token_id=tokenizer['END'])
    else:
        sample_out = model.generate(prompt, do_sample=True, max_length=200, min_length=prompt.size(1) + 1,
                                    top_k=len(tokenizer), top_p=args.top_p, num_return_sequences=50,
                                    pad_token_id=tokenizer['PAD'], eos_token_id=tokenizer['END'])
    sample_out_lst = sample_out[:, prompt.size(1):]
    # sample_out_lst.append(sample_out.cpu())
    # sample_out_lst = torch.cat(sample_out_lst, dim=0)
    text_out = []
    for sample in sample_out_lst:
        sample = sample.tolist()
        words_sample = " ".join([rev_tokenizer[x] for x in sample])
        text_out.append(words_sample)
    return text_out


def generate_gpt2(args, prompt=None):
    print(f'loading from {args.model_name_or_path}')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,  # path to the AR model trained for LMing this task.
    ).cuda()

    # load tokenizer.
    sample_out_lst = []
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    sample_out = model.generate(prompt, do_sample=True, max_length=200,
                                top_k=len(tokenizer), top_p=args.top_p, num_return_sequences=50,
                                pad_token_id=tokenizer.pad_token_id)
    sample_out = sample_out[:, prompt.size(1):]
    sample_out_lst.append(sample_out.cpu())
    sample_out_lst = torch.cat(sample_out_lst, dim=0)

    if args.output_text.endswith('json'):
        with open(args.output_text, 'w') as f:
            for sample in sample_out_lst:
                sample = sample.tolist()
                words_sample = tokenizer.decode(sample, skip_special_tokens=True)
                print(json.dumps([words_sample]), file=f)
    else:
        with open(args.output_text, 'w') as f:
            for sample in sample_out_lst:
                sample = sample.tolist()
                words_sample = tokenizer.decode(sample, skip_special_tokens=True)
                print(words_sample, file=f)

    agg_loss = []
    for tokenized_x in sample_out:
        labels = tokenized_x.clone()
        labels[labels == tokenizer.eos_token_id] = -100
        model_output = model(tokenized_x, labels=labels)
        agg_loss.append(model_output.loss.item())

    print(f'\nthe mean loss is {torch.tensor(agg_loss).mean()}')
    print('-' * 50)


def generate(args):
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,  # path to the AR model trained for LMing this task.
    ).cuda()

    print(model.transformer.wte)
    # print(model)
    # load tokenizer.
    tokenizer = load_tokenizer(args.modality, args.experiment, os.path.split(args.model_path)[0])
    reverse_tokenizer = {v: k for k, v in tokenizer.items()}
    print(len(tokenizer))

    init_prompt = torch.LongTensor([reverse_tokenizer['START']]).view(1, 1).expand(50, -1).to(model.device)
    sample_out = model.generate(init_prompt, do_sample=True, max_length=64,
                                top_k=len(tokenizer), top_p=args.top_p)
    print(sample_out.shape)

    if args.output_text.endswith('json'):
        with open(args.output_text, 'w') as f:
            for sample in sample_out:
                sample = sample.tolist()
                words_sample = [tokenizer[s] for s in sample]
                print(json.dumps([" ".join(words_sample)]), file=f)
    else:
        with open(args.output_text, 'w') as f:
            for sample in sample_out:
                sample = sample.tolist()
                words_sample = [tokenizer[s] for s in sample]
                print(" ".join(words_sample), file=f)

    agg_loss = []
    for tokenized_x in sample_out:
        model_output = model(tokenized_x, labels=tokenized_x)
        agg_loss.append(model_output.loss.item())

    print(f'\nthe mean loss is {torch.tensor(agg_loss).mean()}')
    print('-' * 50)

    ##################

    text_samples = []
    if args.output_text.endswith('json'):
        with open(args.output_text, 'r') as f:
            for line in f:
                text_samples.append(json.loads(line)[0].split(' '))
    else:
        with open(args.output_text, 'r') as f:
            for line in f:
                text_samples.append(line.strip().split())

    agg_loss = []
    for idx, x in enumerate(text_samples):
        # print(x)
        tokenized_x = [reverse_tokenizer[s] for s in x]
        tokenized_x = torch.LongTensor(tokenized_x).cuda()
        # print(tokenized_x)
        # print(sample_out[idx])
        # print((tokenized_x == sample_out[idx]).all())
        model_output = model(tokenized_x, labels=tokenized_x)
        # print(model_output.loss)
        agg_loss.append(model_output.loss.item())

    print(f'\nthe mean loss is {torch.tensor(agg_loss).mean()} for {args.input_text}')


def create_argparser():
    defaults = dict(
        clip_denoised=True,
        num_samples=50,  # 10000
        batch_size=16,
        use_ddim=False,
        model_path="",
        model_arch='conv-unet',
        verbose='yes',
        finetune='yes',
        generation_mode='sample',
    )
    text_defaults = dict(modality='text',
                         dataset_name='wikitext',
                         input_text='',
                         mode='eval',
                         output_text='',
                         dataset_config_name='wikitext-2-raw-v1',
                         model_name_or_path='predictability/diff_models/compress_e=5_b=60_m=gpt2_wikitext-103-raw-v1_None',
                         experiment='gpt2_pre_compress', model_arch='trans-unet',
                         preprocessing_num_workers=1, top_p=1.0)
    defaults.update(model_and_diffusion_defaults())
    defaults.update(text_defaults)
    # defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == '__main__':
    with torch.no_grad():
        main()
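A minimal sketch of driving generate_samples above with an off-the-shelf GPT-2; the Args stand-in and the 'gpt2' checkpoint are illustrative assumptions, not this script's CLI:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2').cuda()
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

class Args:  # hypothetical stand-in for the argparse namespace
    generation_mode = 'sample'
    top_p = 1.0

prompt = tokenizer('START', return_tensors='pt').input_ids.cuda()
with torch.no_grad():
    print(generate_samples(Args(), prompt, model, tokenizer))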
src/control_gen/eval_control.py ADDED
@@ -0,0 +1,567 @@
import torch, argparse, json
import benepar, spacy_stanza
import numpy as np
import sys, os
import csv
from nltk.tree import Tree
sys.path.insert(0, os.path.join(sys.path[0], '../scripts/'))
from tree_helper import chart_from_tree, pad_charts, padded_chart_from_spans
sys.path.insert(0, os.path.join(sys.path[0], '../../misc/self-attentive-parser/src/'))
import evaluate  # evalb wrapper from self-attentive-parser
from spacy.lang.en import English
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer
from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer

nlp = English()
tokenizer_spacy = nlp.tokenizer


def eval_ppl2(args, text_samples):
    print(f'loading from {args.model_name_or_path}')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,  # path to the AR model trained for LMing this task.
    ).cuda()

    if 'r2l' in args.model_name_or_path:
        print('Use the right-to-left encoding.')

    args.model_path = 'predictability/diffusion_models_v6/diff_e2e-tgt_pad_rand16_transformer_' \
                      'lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart/ema_0.9999_200000.pt'
    tokenizer = load_tokenizer('e2e-tgt', 'random', os.path.split(args.model_path)[0])
    # print(args.modality, tokenizer, args.experiment)
    reverse_tokenizer = {v: k for k, v in tokenizer.items()}
    full_score = []
    for idxx, (gold, full_word_lst) in enumerate(text_samples.items()):
        # print(len(full_word_lst), full_word_lst[0])
        agg_loss = []
        for x in full_word_lst:
            # x = " ".join(x).split()
            if 'r2l' in args.model_name_or_path:
                string = ["START"] + list(reversed(x)) + ["END"]
                tokenized_x = [reverse_tokenizer.get(s, reverse_tokenizer['UNK']) for s in string]
            else:
                tokenized_x = [reverse_tokenizer['START']] \
                              + [reverse_tokenizer.get(s, reverse_tokenizer['UNK']) for s in x] \
                              + [reverse_tokenizer['END']]
            # print(tokenized_x)
            tokenized_x = torch.LongTensor(tokenized_x).cuda()
            labels = tokenized_x.clone()
            labels[labels == reverse_tokenizer['PAD']] = -100
            model_output = model(tokenized_x, labels=labels)
            # print(model_output.loss)
            # if idxx == 3:
            #     print(tokenized_x, model_output.loss.item())
            agg_loss.append(model_output.loss.item())
        example_mean_score = torch.tensor(agg_loss).mean()
        # print(f'\nthe mean loss is {example_mean_score} for index', idxx)
        full_score.append(example_mean_score)
    full_score_ = np.array(full_score).mean()
    print(f'full NLL score is {full_score_} for {len(full_score)}')
    print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}')


def eval_ppl(args, text_samples):
    '''
    Evaluate the generations with a GPT-2 model finetuned on this task.
    :param text_samples: dict mapping each control target to its list of tokenized generations.
    :return:
    '''
    # load model
    print(f'loading from {args.model_name_or_path}')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,  # path to the AR model trained for LMing this task.
    ).cuda()

    # load tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    print('finished loading models.')

    args.model_path = 'predictability/diffusion_models_v6/diff_e2e-tgt_pad_rand16_transformer_' \
                      'lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart/ema_0.9999_200000.pt'
    diff_tokenizer = load_tokenizer('e2e-tgt', 'random', os.path.split(args.model_path)[0])
    reverse_diff_tokenizer = {v: k for k, v in diff_tokenizer.items()}

    full_score = []
    for gold, full_word_lst in text_samples.items():
        agg_loss = []
        for x in full_word_lst:
            x = [kk if kk in reverse_diff_tokenizer else 'UNK' for kk in x]
            x = tokenizer.bos_token + " ".join(x) + tokenizer.eos_token
            # print(x)
            # should also add BOS EOS token?

            tokenized_x = tokenizer(x, return_tensors='pt')  # [reverse_tokenizer[s] for s in x]
            input_ids = tokenized_x['input_ids'].cuda()
            labels = input_ids.clone()
            # print(tokenized_x)
            # tokenized_x = torch.LongTensor(tokenized_x).cuda()
            # labels = tokenized_x.clone()
            # labels[labels == reverse_tokenizer['PAD']] = -100
            model_output = model(input_ids, labels=labels)
            agg_loss.append(model_output.loss.item())
        example_mean_score = torch.tensor(agg_loss).mean()
        # print(f'\nthe mean loss is {example_mean_score}')
        full_score.append(example_mean_score)
    full_score_ = np.array(full_score).mean()
    print(f'full NLL score is {full_score_} for {len(full_score)}')
    print(f'full PPL score is {np.e ** full_score_} for {len(full_score)}')


def read_files(args):
    '''
    :param args:
    :return: list of tokenized sentences.
    '''
    if args.input_format == 'file':
        text_samples = []
        if args.input_text.endswith('json'):
            with open(args.input_text, 'r') as f:
                for line in f:
                    words = [x.text for x in tokenizer_spacy(json.loads(line)[0])]
                    text_samples.append(words)
                    # text_samples.append(json.loads(line)[0].split(' '))
        else:
            with open(args.input_text, 'r') as f:
                for line in f:
                    text_samples.append(line.strip().split())

        # remove trailing PAD tokens.
        text_samples2 = []
        for sent in text_samples:
            tempsent = [x for x in sent if x != 'PAD']
            if tempsent[0] == 'START':
                tempsent = tempsent[1:]
            if tempsent[-1] == 'END':
                tempsent = tempsent[:-1]
            if tempsent[-1] == '\n' and args.mode in ['e2e-tgt-tree', 'e2e-tgt-tree-paired']:
                tempsent[-1] = '.'
            text_samples2.append(tempsent)
        return text_samples2
    elif args.input_format == 'paired':
        import ast
        # nlp = English()
        # tokenizer = nlp.tokenizer
        result_lst = defaultdict(list)

        if args.input_text.endswith('json'):
            with open(args.input_text, 'r') as f:
                for line in f:
                    try:
                        line = json.loads(line)
                    except json.JSONDecodeError:
                        if args.mode == 'e2e-tgt-spans-paired':
                            line = ast.literal_eval(line)
                            line = {tuple(ast.literal_eval(k[0])): v for k, v in line.items()}
                            result_lst.update(line)
                        else:
                            line = ast.literal_eval(line)
                            result_lst.update(line)
        elif args.input_text.endswith('log'):
            with open(args.input_text, 'r') as csvfile:
                roc_reader = csv.reader(csvfile)  # delimiter=' ', quotechar='|')
                for idx, row in enumerate(roc_reader):
                    if idx == 0:
                        continue
                    if args.mode == 'e2e-tgt-spans-paired' or args.mode == 'e2e-tgt-length-paired':
                        pos = tuple(ast.literal_eval(row[0]))
                        if args.mode == 'e2e-tgt-length-paired':
                            pos = list(pos)
                            pos[0] = int(pos[0]) + 2  # because this count did not account for START and END
                            pos = tuple(pos)
                    else:
                        pos = tuple(row[0].split())
                    result_lst[pos].append(row[2])

        clean_result_lst = {}
        for k, text_samples in result_lst.items():
            text_samples2 = []
            for sent in text_samples:
                sent = sent.split(' ')
                # KEY DEBUG.
                # sent = [x.text for x in tokenizer_spacy(sent)]
                # print(sent, sent2)
                tempsent = [x for x in sent if x != 'PAD']
                if tempsent[0] == 'START':
                    tempsent = tempsent[1:]
                if tempsent[-1] == 'END':
                    tempsent = tempsent[:-1]
                if tempsent[-1] == '\n' and args.mode == 'e2e-tgt-tree':
                    tempsent[-1] = '.'

                # KEY DEBUG.
                tempsent = " ".join(tempsent)
                tempsent = [x.text for x in tokenizer_spacy(tempsent)]
                text_samples2.append(tempsent)
            if k[0] == 'START' and k[-1] == 'END':
                kk_ = k[1:-1]
            else:
                kk_ = k
            clean_result_lst[kk_] = text_samples2  # remove START and END from the training data.
        return clean_result_lst


def eval_parse(parser, generated, tree_vocab):
    sent_lst = []
    for sent in generated:
        # print(sent)
        input_sentence1 = benepar.InputSentence(
            words=sent,
        )
        sent_lst.append(input_sentence1)
    parse_lst = list(parser.parse_sents(sent_lst))
    assert len(parse_lst) == len(generated)
    spans_lst = []
    for parse in parse_lst:
        chart, spans = chart_from_tree(tree_vocab, parse, verbose=True)
        spans_lst.append(spans)
    return parse_lst, spans_lst


def levenshteinDistance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]


def score_spans(gold_spans, generated_span):
    print(gold_spans)
    print(generated_span)
    gold_spans = set([gold_spans])
    generated_span = set(generated_span)
    intersection = gold_spans.intersection(generated_span)
    print(intersection, len(intersection) / len(gold_spans))
    # union = gold_spans.union(generated_span)
    # print(len(union), len(intersection))
    return len(intersection) / len(gold_spans)


def score_tree(gold_tree, pred_trees):
    # print([x.leaves() for x in pred_trees])

    def reset_leaves(tree_):
        simple_increm = 0
        for s in tree_.subtrees(lambda t: t.height() == 2):
            s[0] = simple_increm
            s._label = 'NN'
            simple_increm += 1
        return simple_increm

    # reset.
    increm_gold = reset_leaves(gold_tree)
    # print(increm_gold)
    for i, pred in enumerate(pred_trees):
        increm_pred = reset_leaves(pred)
        # print(increm_pred, 'pred', i)

    use_evalb = True
    if use_evalb:
        # print(len(gold_tree), len(pred_trees), gold_tree)
        gold_trees = [gold_tree] * len(pred_trees)
        print(len(gold_tree.leaves()), [len(x.leaves()) for x in pred_trees])
        # print(pred_trees[0])
        dev_fscore = evaluate.evalb('diffusion_lm/misc/self-attentive-parser/EVALB',
                                    gold_trees, pred_trees)
        print(dev_fscore)

    return dev_fscore


def score_pos(gold_pos, generated_pos):
    ed = levenshteinDistance(gold_pos, generated_pos)
    return 1 - (ed / len(gold_pos))


def score_pos_em(gold_pos, generated_pos):
    # print(len(gold_pos), len(generated_pos), gold_pos, generated_pos)
    if len(generated_pos) > len(gold_pos):
        generated_pos = generated_pos[:len(gold_pos)]
    elif len(generated_pos) < len(gold_pos):
        generated_pos = generated_pos + ['PAD'] * (len(gold_pos) - len(generated_pos))
    assert len(gold_pos) == len(generated_pos)
    correct = 0
    total = 0
    for x1, x2 in zip(gold_pos, generated_pos):
        if x1 == x2:
            correct += 1
        total += 1
    return correct / total
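A quick worked example of the two POS scorers just defined (the values are exact):

gold = ['DET', 'NOUN', 'VERB', 'PUNCT']
pred = ['DET', 'NOUN', 'NOUN', 'PUNCT']
print(levenshteinDistance(gold, pred))  # 1, a single substitution
print(score_pos(gold, pred))            # 0.75 = 1 - 1/4, edit distance normalized by gold length
print(score_pos_em(gold, pred))         # 0.75, per-position exact-match accuracy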
def score_attributes(gold_att, generated):
    if gold_att in generated:
        return 1.
    else:
        return 0.


def eval_pos(tagger, generated_text):
    generated_pos = []
    for sent in generated_text:
        sent_full = " ".join(sent)
        doc = tagger(sent_full)
        generated_pos.append([token.pos_ for token in doc])
    return generated_pos


def eval_(args, text_samples):
    if args.mode == 'e2e-tgt-tree':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]
        if args.gold_ref == 'full':
            # toy1 = 'START Located in riverside area , Alimentum restaurant is a place to bring the whole family . \n END'.split()
            # toy1 = 'START Alimentum is not a family - friendly place , located in city centre . \n END'.split()
            toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating',
                    '.', 'Prices', 'start', 'at', '£', '30', '.', 'END']
            input_sentence1 = benepar.InputSentence(
                words=toy1[1:-1],
            )
            gold_parse = list(parser.parse_sents([input_sentence1]))[0]
            chart, gold_spans = chart_from_tree(tree_vocab, gold_parse, verbose=True)
            print(len(toy1[1:-1]), len(list(gold_parse.leaves())))
        elif args.gold_ref == 'span':
            # spans = [(10, 14, 'ADJP')]
            gold_spans = [(0, 4, 'S::VP')]
            gold_spans = [(0, 0, 'NP')]
            gold_spans = [(9, 13, 'ADJP')]
            # gold_spans = [(9, 13, 'PP')]

        print(text_samples[:1])
        # correct for length:
        target_len = len(gold_parse.leaves())
        print(gold_parse.leaves(), 'target')
        for i, x in enumerate(text_samples):
            if len(x) == target_len:
                continue
            elif len(x) > target_len:
                text_samples[i] = x[:target_len]
            else:
                print('padded to same length', (target_len - len(x)))
                text_samples[i] = x + ['.'] * (target_len - len(x))
                # print(text_samples[i])
        generated_parse, generated_span = eval_parse(parser, text_samples, tree_vocab)
        # print(gold_spans)
        # print(generated_span[:2])
        evalb_score = score_tree(gold_parse, generated_parse)
        print([len(x) for x in text_samples])
        score_lst = []
        for x in generated_span:
            score_lst.append(score_spans(gold_spans, x))
        print(np.array(score_lst).mean())
    elif args.mode == 'e2e-tgt-pos':
        tagger = spacy_stanza.load_pipeline("en", processors='tokenize,mwt,pos')  # processors={"tokenize": "spacy"}
        if args.gold_ref == 'full':
            toy1 = 'START The Mill is a coffee shop with an expensive menu near The Sorrento . \n END'.split()
            toy1 = ['START', 'The', 'Vaults', 'pub', 'near', 'Café', 'Adriatic', 'has', 'a', '5', 'star', 'rating', '.',
                    'Prices', 'start', 'at', '£', '30', '.', '\n', 'END']
            sent_full = " ".join(toy1[1:-1])
            doc = tagger(sent_full)
            gold_pos = [token.pos_ for token in doc]
        elif args.gold_ref == 'span':
            gold_pos = [(9, 'PROPN')]

        generated_pos = eval_pos(tagger, text_samples)
        score_lst = []
        score_lst2 = []
        for x in generated_pos:
            print(gold_pos)
            print(x)
            print()
            score_lst.append(score_pos(gold_pos, x))
            score_lst2.append(score_pos_em(gold_pos, x))
        print(np.array(score_lst).mean())
        print(np.array(score_lst2).mean())
    elif args.mode == 'e2e-tgt-pos-paired':
        import stanza
        nlp = spacy_stanza.load_pipeline("en", processors={"tokenize": "spacy"})
        print(nlp)
        # nlp = stanza.Pipeline("en", processors={"tokenize": "spacy", 'pos': 'combined'}, package=None)

        full_score = []
        for gold, full_word_lst in text_samples.items():
            print(gold, len(full_word_lst), full_word_lst[:2])
            # full_word_lst = full_word_lst[:2]
            sent_lst = [" ".join(seq) for seq in full_word_lst]
            sent_full = " ".join(sent_lst)
            # print(sent_lst)
            try:
                doc = nlp(sent_full)
                doc_token_pos = [(token.text, token.pos_,) for token in doc]
                len_lst = [len(seq) for seq in full_word_lst]
                print(sum(len_lst), len(doc_token_pos), 'should be equal!')
                assert sum(len_lst) == len(doc_token_pos)
                pos_lst = []
                init_idx = 0
                for len_temp in len_lst:
                    pos_lst.append([x[1] for x in doc_token_pos[init_idx:init_idx + len_temp]])
                    init_idx = init_idx + len_temp
            except Exception:
                print(f'stanza pipeline failed... for this {gold}')
                # parse each sentence separately...
                pos_lst = []
                for single_sent in sent_lst:
                    doc = nlp(single_sent)
                    # doc_token_pos = [(token.text, token.pos_,) for token in doc]
                    pos_lst.append([token.pos_ for token in doc])

            score_lst = []
            score_lst2 = []
            for x in pos_lst:
                score_lst.append(score_pos(gold, x))
                score_lst2.append(score_pos_em(gold, x))
            score_ed = np.array(score_lst).mean()
            score_em = np.array(score_lst2).mean()
            print(len(score_lst), score_ed, score_em)
            full_score.append(score_em)
        full_score_em = np.array(full_score).mean()
        print(full_score_em, f"± {np.array(full_score).std()}", len(full_score))

    if args.mode == 'e2e-tgt-tree-paired':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]

        full_score = []
        for idx, (gold_parse, full_word_lst) in enumerate(text_samples.items()):
            # to avoid an evalb complaint --> change \n to .
            gold_parse_str = gold_parse[0]
            gold_parse_str = gold_parse_str.replace('\n', '.')
            # print([gold_parse_str], 'gold tree string')
            gold_parse = Tree.fromstring(gold_parse_str)
            target_len = len(gold_parse.leaves())
            # print(gold_parse.leaves(), 'target')
            # print(full_word_lst)
            for i, x in enumerate(full_word_lst):
                if len(x) == target_len:
                    continue
                elif len(x) > target_len:
                    print('generated seq is longer than gold seq')
                    full_word_lst[i] = x[:target_len]
                else:
                    print('padded to same length', (target_len - len(x)))
                    full_word_lst[i] = x + ['.'] * (target_len - len(x))
            generated_parse, generated_span = eval_parse(parser, full_word_lst, tree_vocab)
            evalb_score = score_tree(gold_parse, generated_parse)  # inputs are nltk.Tree
            # print(type(evalb_score))
            print(evalb_score.fscore)
            full_score.append(evalb_score.fscore)
        full_score_f1 = np.array(full_score).mean()
        # print(full_score_f1, len(full_score))
        print(full_score_f1, f"± {np.array(full_score).std()}", len(full_score))

    elif args.mode == 'e2e-tgt-spans-paired':
        parser = benepar.Parser("benepar_en3")
        tree_vocab = parser._parser.config["label_vocab"]

        full_score = []
        for idx, (gold_spans, full_word_lst) in enumerate(text_samples.items()):
            # to avoid an evalb complaint --> change \n to .
            print(gold_spans, '11 gold')
            generated_parse, generated_span = eval_parse(parser, full_word_lst, tree_vocab)
            score_lst = []
            for x in generated_span:
                score_lst.append(score_spans(gold_spans, x))
            print(score_lst)
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_span = np.array(full_score).mean()
        print(full_score_span, f"± {np.array(full_score).std()}", len(full_score))

    if args.mode == 'e2e-tgt-attribute-paired':
        full_score = []
        for idx, (attribute, full_word_lst) in enumerate(text_samples.items()):
            # print(attribute)
            attribute = " ".join(attribute).split(':')[1].strip()
            gold_attribute = attribute
            score_lst = []
            for i, x in enumerate(full_word_lst):
                # print(gold_attribute, x)
                score_lst.append(score_attributes(gold_attribute, " ".join(x)))
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_mean = np.array(full_score).mean()
        # print(full_score_mean, len(full_score))
        print(full_score_mean, f"± {np.array(full_score).std()}", len(full_score))

    if args.mode == 'e2e-tgt-length-paired':
        full_score = []
        for idx, (attribute, full_word_lst) in enumerate(text_samples.items()):
            tgt_len = int(attribute[0]) - 2  # remove START and END.
            score_lst = []
            for i, x in enumerate(full_word_lst):
                if tgt_len == len(x):
                    # if np.abs(tgt_len - len(x)) <= 2:
                    score_lst.append(1.)
                else:
                    score_lst.append(0.)
            score_lst_mean = np.array(score_lst).mean()
            full_score.append(score_lst_mean)
        full_score_mean = np.array(full_score).mean()
        # print(full_score_mean, len(full_score))
        print(full_score_mean, f"± {np.array(full_score).std()}", len(full_score))

    elif args.mode == 'e2e-tgt-attribute':
        gold_attribute = ""
        score_lst = []
        for x in text_samples:
            score_lst.append(score_attributes(gold_attribute, x))
        print(np.array(score_lst).mean())


if __name__ == '__main__':
    # 'diffusion_lm/improved_diffusion/out_gen/diff_e2e-tgt_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart.ema_0.9999_200000.pt.infill_control_tree_50x64x16_tree_partial-cat-lgv0.1.json'
    parser = argparse.ArgumentParser(description='training args.')
    parser.add_argument('--input_text', type=str,
                        default='diffusion_lm/improved_diffusion/out_gen/diff_e2e-tgt_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart.ema_0.9999_200000.pt.'
                                'infill_control_tree_50x64x16_tree_partial-cat-lgv0.1.json')
    parser.add_argument('--input_format', type=str, default='batch', help='wp, wikitext')
    parser.add_argument('--mode', type=str, default='e2e-tgt-tree', help='')
    parser.add_argument('--gold_ref', type=str, default='full', help='')
    parser.add_argument('--model_name_or_path', type=str,
                        default='predictability/diff_models/e2e-tgt_e=20_b=64_m=gpt2_wikitext-103-raw-v1_101_wp_finetune_UNK', help='')
    # default='predictability/diff_models/e2e-tgt_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad', help='')

    args = parser.parse_args()
    text_samples = read_files(args)
    eval_(args, text_samples)
    eval_ppl(args, text_samples)
    # eval_ppl2(args, text_samples)
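For reference, a sketch of the paired .log layout that read_files consumes with --input_format paired --mode e2e-tgt-pos-paired; the column names are illustrative, the reader only assumes a header row, the control target in column 0, and the generation in column 2:

import csv

with open('toy_control.log', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['control', 'unused', 'generation'])
    writer.writerow(['DET NOUN VERB', '', 'START The Mill is a pub . END'])

read_files would then map the control key ('DET', 'NOUN', 'VERB') to the spaCy-retokenized generation, after stripping START, END, and PAD.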
src/ev.py ADDED
@@ -0,0 +1,117 @@
import numpy as np
import os.path as osp
from nltk.translate.bleu_score import corpus_bleu
from rdkit import RDLogger
from Levenshtein import distance as lev
from rdkit import Chem
from rdkit.Chem import MACCSkeys
from rdkit import DataStructs
from rdkit.Chem import AllChem
RDLogger.DisableLog('rdApp.*')
from fcd import get_fcd, load_ref_model, canonical_smiles
import warnings
import os
warnings.filterwarnings('ignore')


def get_smis(filepath):
    print(filepath)
    with open(filepath) as f:
        lines = f.readlines()
    gt_smis = []
    op_smis = []
    for s in lines:
        if len(s) < 3:
            continue
        s0, s1 = s.split(' || ')
        s0 = s0.strip().replace('[EOS]', '').replace('[SOS]', '').replace('[X]', '').replace('[XPara]', '').replace('[XRing]', '')
        s1 = s1.strip()
        gt_smis.append(s1)
        op_smis.append(s0)
    return gt_smis, op_smis


def evaluate(gt_smis, op_smis):
    references = []
    hypotheses = []
    for i, (gt, out) in enumerate(zip(gt_smis, op_smis)):
        gt_tokens = [c for c in gt]
        out_tokens = [c for c in out]
        references.append([gt_tokens])
        hypotheses.append(out_tokens)
    # BLEU score
    bleu_score = corpus_bleu(references, hypotheses)

    references = []
    hypotheses = []
    levs = []
    num_exact = 0
    bad_mols = 0
    for i, (gt, out) in enumerate(zip(gt_smis, op_smis)):
        hypotheses.append(out)
        references.append(gt)
        try:
            m_out = Chem.MolFromSmiles(out)
            m_gt = Chem.MolFromSmiles(gt)
            if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt):
                num_exact += 1
        except Exception:
            bad_mols += 1
        levs.append(lev(out, gt))
    # Exact matching score
    exact_match_score = num_exact / (i + 1)
    # Levenshtein score
    levenshtein_score = np.mean(levs)
    validity_score = 1 - bad_mols / len(gt_smis)
    return bleu_score, exact_match_score, levenshtein_score, validity_score


def fevaluate(gt_smis, op_smis, morgan_r=2):
    outputs = []
    bad_mols = 0
    for n, (gt_smi, ot_smi) in enumerate(zip(gt_smis, op_smis)):
        try:
            gt_m = Chem.MolFromSmiles(gt_smi)
            ot_m = Chem.MolFromSmiles(ot_smi)
            if ot_m is None:
                raise ValueError('Bad SMILES')
            outputs.append((gt_m, ot_m))
        except Exception:
            bad_mols += 1
    validity_score = len(outputs) / (len(outputs) + bad_mols)

    MACCS_sims = []
    morgan_sims = []
    RDK_sims = []
    enum_list = outputs
    for i, (gt_m, ot_m) in enumerate(enum_list):
        MACCS_sims.append(DataStructs.FingerprintSimilarity(MACCSkeys.GenMACCSKeys(gt_m), MACCSkeys.GenMACCSKeys(ot_m), metric=DataStructs.TanimotoSimilarity))
        RDK_sims.append(DataStructs.FingerprintSimilarity(Chem.RDKFingerprint(gt_m), Chem.RDKFingerprint(ot_m), metric=DataStructs.TanimotoSimilarity))
        morgan_sims.append(DataStructs.TanimotoSimilarity(AllChem.GetMorganFingerprint(gt_m, morgan_r), AllChem.GetMorganFingerprint(ot_m, morgan_r)))

    maccs_sims_score = np.mean(MACCS_sims)
    rdk_sims_score = np.mean(RDK_sims)
    morgan_sims_score = np.mean(morgan_sims)
    return validity_score, maccs_sims_score, rdk_sims_score, morgan_sims_score


def fcdevaluate(qgt_smis, qop_smis):
    gt_smis = []
    ot_smis = []
    for n, (gt_smi, ot_smi) in enumerate(zip(qgt_smis, qop_smis)):
        if len(ot_smi) == 0:
            ot_smi = '[]'
        gt_smis.append(gt_smi)
        ot_smis.append(ot_smi)
    model = load_ref_model()
    canon_gt_smis = [w for w in canonical_smiles(gt_smis) if w is not None]
    canon_ot_smis = [w for w in canonical_smiles(ot_smis) if w is not None]
    fcd_sim_score = get_fcd(canon_gt_smis, canon_ot_smis, model)
    return fcd_sim_score


os.environ["CUDA_VISIBLE_DEVICES"] = "1"
gt, op = get_smis('output.txt')
bleu_score, exact_match_score, levenshtein_score, _ = evaluate(gt, op)
validity_score, maccs_sims_score, rdk_sims_score, morgan_sims_score = fevaluate(gt, op)
fcd_metric_score = fcdevaluate(gt, op)
print(f'BLEU: {round(bleu_score, 3)}')
print(f'Exact: {round(exact_match_score, 3)}')
print(f'Levenshtein: {round(levenshtein_score, 3)}')
print(f'MACCS FTS: {round(maccs_sims_score, 3)}')
print(f'RDK FTS: {round(rdk_sims_score, 3)}')
print(f'Morgan FTS: {round(morgan_sims_score, 3)}')
print(f'FCD Metric: {round(fcd_metric_score, 3)}')
print(f'Validity: {round(validity_score, 3)}')
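get_smis expects one "generated || ground truth" pair per line, with the model's special tokens still attached to the generated side; a two-line output.txt would look like this (toy SMILES, purely for illustration):

with open('output.txt', 'w') as f:
    f.write('[SOS]CCO[EOS] || CCO\n')      # special tokens are stripped from the generated side
    f.write('c1ccccc1 || c1ccccc1\n')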
src/evaluation/fcd_metric.py ADDED
@@ -0,0 +1,54 @@
'''
Code from https://github.com/blender-nlp/MolT5

```bibtex
@article{edwards2022translation,
  title={Translation between Molecules and Natural Language},
  author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng},
  journal={arXiv preprint arXiv:2204.11817},
  year={2022}
}
```
'''

import argparse
import csv

import os.path as osp

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from fcd import get_fcd, load_ref_model, canonical_smiles


def evaluate(input_file, verbose=False):
    gt_smis = []
    ot_smis = []

    with open(osp.join(input_file)) as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for n, line in enumerate(reader):
            gt_smi = line['ground truth']
            ot_smi = line['output']
            if len(ot_smi) == 0:
                ot_smi = '[]'
            gt_smis.append(gt_smi)
            ot_smis.append(ot_smi)

    model = load_ref_model()

    canon_gt_smis = [w for w in canonical_smiles(gt_smis) if w is not None]
    canon_ot_smis = [w for w in canonical_smiles(ot_smis) if w is not None]

    fcd_sim_score = get_fcd(canon_gt_smis, canon_ot_smis, model)
    if verbose:
        print('FCD Similarity:', fcd_sim_score)

    return fcd_sim_score


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, default='caption2smiles_example.txt', help='path where test generations are saved')
    args = parser.parse_args()
    evaluate(args.input_file, True)
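Note that canonical_smiles yields None for strings RDKit cannot parse, which the two list comprehensions above filter out before the Frechet ChemNet Distance is computed. A small illustration; the printed values are assumed from the fcd package's documented behavior:

from fcd import canonical_smiles
print(list(canonical_smiles(['CCO', 'not-a-smiles'])))  # expected: ['CCO', None]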
src/evaluation/fingerprint_metrics.py ADDED
@@ -0,0 +1,81 @@
'''
Code from https://github.com/blender-nlp/MolT5

```bibtex
@article{edwards2022translation,
  title={Translation between Molecules and Natural Language},
  author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng},
  journal={arXiv preprint arXiv:2204.11817},
  year={2022}
}
```
'''

import argparse
import csv

import os.path as osp

import numpy as np

from rdkit import Chem
from rdkit.Chem import MACCSkeys
from rdkit import DataStructs
from rdkit.Chem import AllChem

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


def evaluate(input_file, morgan_r, verbose=False):
    outputs = []
    bad_mols = 0

    with open(osp.join(input_file)) as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for n, line in enumerate(reader):
            try:
                gt_smi = line['ground truth']
                ot_smi = line['output']
                gt_m = Chem.MolFromSmiles(gt_smi)
                ot_m = Chem.MolFromSmiles(ot_smi)
                if ot_m is None:
                    raise ValueError('Bad SMILES')
                outputs.append((line['description'], gt_m, ot_m))
            except Exception:
                bad_mols += 1
    validity_score = len(outputs) / (len(outputs) + bad_mols)
    if verbose:
        print('validity:', validity_score)

    MACCS_sims = []
    morgan_sims = []
    RDK_sims = []

    enum_list = outputs

    for i, (desc, gt_m, ot_m) in enumerate(enum_list):
        if i % 100 == 0:
            if verbose:
                print(i, 'processed.')

        MACCS_sims.append(DataStructs.FingerprintSimilarity(MACCSkeys.GenMACCSKeys(gt_m), MACCSkeys.GenMACCSKeys(ot_m), metric=DataStructs.TanimotoSimilarity))
        RDK_sims.append(DataStructs.FingerprintSimilarity(Chem.RDKFingerprint(gt_m), Chem.RDKFingerprint(ot_m), metric=DataStructs.TanimotoSimilarity))
        morgan_sims.append(DataStructs.TanimotoSimilarity(AllChem.GetMorganFingerprint(gt_m, morgan_r), AllChem.GetMorganFingerprint(ot_m, morgan_r)))

    maccs_sims_score = np.mean(MACCS_sims)
    rdk_sims_score = np.mean(RDK_sims)
    morgan_sims_score = np.mean(morgan_sims)
    if verbose:
        print('Average MACCS Similarity:', maccs_sims_score)
        print('Average RDK Similarity:', rdk_sims_score)
        print('Average Morgan Similarity:', morgan_sims_score)
    return validity_score, maccs_sims_score, rdk_sims_score, morgan_sims_score


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, default='caption2smiles_example.txt', help='path where test generations are saved')
    parser.add_argument('--morgan_r', type=int, default=2, help='morgan fingerprint radius')
    args = parser.parse_args()

    evaluate(args.input_file, args.morgan_r, True)
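The evaluation scripts in this directory share a tab-separated input with a description / ground truth / output header row (see the DictReader calls above). A toy file plus invocation, purely for illustration; scores on one molecule are not meaningful:

with open('caption2smiles_example.txt', 'w') as f:
    f.write('description\tground truth\toutput\n')
    f.write('ethanol\tCCO\tCCO\n')

validity, maccs, rdk, morgan = evaluate('caption2smiles_example.txt', morgan_r=2, verbose=True)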
src/evaluation/mol_translation_metrics.py ADDED
@@ -0,0 +1,129 @@
'''
Code from https://github.com/blender-nlp/MolT5

```bibtex
@article{edwards2022translation,
  title={Translation between Molecules and Natural Language},
  author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng},
  journal={arXiv preprint arXiv:2204.11817},
  year={2022}
}
```
'''

import argparse
import csv

import os.path as osp

import numpy as np

# load metric stuff
from nltk.translate.bleu_score import corpus_bleu
# from nltk.translate.meteor_score import meteor_score

from Levenshtein import distance as lev

from rdkit import Chem

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


def evaluate(input_fp, verbose=False):
    outputs = []

    with open(osp.join(input_fp)) as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for n, line in enumerate(reader):
            gt_smi = line['ground truth']
            ot_smi = line['output']
            outputs.append((line['description'], gt_smi, ot_smi))

    # meteor_scores = []

    references = []
    hypotheses = []

    for i, (smi, gt, out) in enumerate(outputs):
        if i % 100 == 0:
            if verbose:
                print(i, 'processed.')

        gt_tokens = [c for c in gt]
        out_tokens = [c for c in out]

        references.append([gt_tokens])
        hypotheses.append(out_tokens)

        # mscore = meteor_score([gt], out)
        # meteor_scores.append(mscore)

    # BLEU score
    bleu_score = corpus_bleu(references, hypotheses)
    if verbose:
        print('BLEU score:', bleu_score)

    # Meteor score
    # _meteor_score = np.mean(meteor_scores)
    # print('Average Meteor score:', _meteor_score)

    references = []
    hypotheses = []

    levs = []
    num_exact = 0
    bad_mols = 0

    for i, (smi, gt, out) in enumerate(outputs):
        hypotheses.append(out)
        references.append(gt)

        try:
            m_out = Chem.MolFromSmiles(out)
            m_gt = Chem.MolFromSmiles(gt)
            if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt):
                num_exact += 1
            # if gt == out: num_exact += 1  # old version that didn't standardize strings
        except Exception:
            bad_mols += 1

        levs.append(lev(out, gt))

    # Exact matching score
    exact_match_score = num_exact / (i + 1)
    if verbose:
        print('Exact Match:')
        print(exact_match_score)

    # Levenshtein score
    levenshtein_score = np.mean(levs)
    if verbose:
        print('Levenshtein:')
        print(levenshtein_score)

    validity_score = 1 - bad_mols / len(outputs)
    if verbose:
        print('validity:', validity_score)

    return bleu_score, exact_match_score, levenshtein_score, validity_score


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, default='caption2smiles_example.txt', help='path where test generations are saved')
    args = parser.parse_args()
    evaluate(args.input_file, verbose=True)
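BLEU here is character-level: each SMILES string is exploded into its characters before corpus_bleu is called, so two strings that differ by one atom still earn partial credit. A minimal check:

from nltk.translate.bleu_score import corpus_bleu

references = [[list('c1ccccc1')]]   # one reference per example, as character lists
hypotheses = [list('c1ccccc1')]
print(corpus_bleu(references, hypotheses))  # 1.0 for an exact match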
src/improved_diffusion/__init__.py ADDED
File without changes
src/improved_diffusion/dist_util.py ADDED
@@ -0,0 +1,87 @@
"""
Helpers for distributed training.
"""

import io
import os
import socket

import blobfile as bf

from mpi4py import MPI
import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 1  # 8

SETUP_RETRY_COUNT = 3


def setup_dist(rank, world_size, port="12145"):
    """
    Set up a distributed process group.
    """
    if dist.is_initialized():
        return

    # comm = MPI.COMM_WORLD
    # backend = "gloo" if not th.cuda.is_available() else "nccl"
    # if backend == "gloo":
    #     hostname = "localhost"
    # else:
    #     hostname = socket.gethostbyname(socket.getfqdn())
    # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    # os.environ["RANK"] = str(comm.rank)
    # os.environ["WORLD_SIZE"] = str(comm.size)
    # port = comm.bcast(_find_free_port(), root=0)
    # os.environ["MASTER_PORT"] = str(port)
    # dist.init_process_group(backend=backend, init_method="env://")
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
    else:
        data = None
    data = MPI.COMM_WORLD.bcast(data)
    return th.load(io.BytesIO(data), **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    # Create the socket before entering the try block so that `s` is always
    # bound when the finally clause runs.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
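A minimal single-node launch sketch for setup_dist; the worker body is an assumption, not part of this module:

import torch.multiprocessing as mp

def _worker(rank, world_size):
    setup_dist(rank, world_size)  # NCCL group on localhost:12145, one process per GPU
    # ... build the model, wrap it in DistributedDataParallel, train ...

if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)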
src/improved_diffusion/fp16_util.py ADDED
@@ -0,0 +1,76 @@
"""
Helpers to train with 16-bit precision.
"""

import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        l.bias.data = l.bias.data.float()


def make_master_params(model_params):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = _flatten_dense_tensors(
        [param.detach().float() for param in model_params]
    )
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]


def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.data.detach().float() for param in model_params]
    )


def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    model_params = list(model_params)

    for param, master_param in zip(
        model_params, unflatten_master_params(model_params, master_params)
    ):
        param.detach().copy_(master_param)


def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), model_params)


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()
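How these helpers compose in one fp16 training step, as a sketch under the assumption of a CUDA device; the toy model is illustrative:

import torch
import torch.nn as nn

model = nn.Linear(4, 4).cuda().half()              # toy fp16 model
model_params = list(model.parameters())
master_params = make_master_params(model_params)   # one flat fp32 copy
opt = torch.optim.SGD(master_params, lr=1e-2)

zero_grad(model_params)
loss = model(torch.randn(2, 4, device='cuda').half()).float().sum()
loss.backward()                                    # grads land on the fp16 params
model_grads_to_master_grads(model_params, master_params)
opt.step()                                         # update the fp32 master copy
master_params_to_model_params(model_params, master_params)  # copy back into the fp16 model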
src/improved_diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1606 @@
"""
This code started out as a PyTorch port of Ho et al.'s diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py

Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""

import enum
import math
import torch
import numpy as np

from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule_name == "sqrt":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: 1 - np.sqrt(t + 0.0001),
        )
    elif schedule_name == "trunc_cos":
        return betas_for_alpha_bar2(
            num_diffusion_timesteps,
            lambda t: np.cos((t + 0.1) / 1.1 * np.pi / 2) ** 2,
        )
    elif schedule_name == "trunc_lin":
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001 + 0.01
        beta_end = scale * 0.02 + 0.01
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "pw_lin":
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001 + 0.01
        beta_mid = scale * 0.0001  # scale * 0.02
        beta_end = scale * 0.02
        first_part = np.linspace(beta_start, beta_mid, 10, dtype=np.float64)
        second_part = np.linspace(
            beta_mid, beta_end, num_diffusion_timesteps - 10, dtype=np.float64
        )
        return np.concatenate([first_part, second_part])
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+ def betas_for_alpha_bar2(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
72
+ """
73
+ Create a beta schedule that discretizes the given alpha_t_bar function,
74
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
75
+
76
+ :param num_diffusion_timesteps: the number of betas to produce.
77
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
78
+ produces the cumulative product of (1-beta) up to that
79
+ part of the diffusion process.
80
+ :param max_beta: the maximum beta to use; use values lower than 1 to
81
+ prevent singularities.
82
+ """
83
+ betas = []
84
+ betas.append(min(1 - alpha_bar(0), max_beta))
85
+ for i in range(num_diffusion_timesteps - 1):
86
+ t1 = i / num_diffusion_timesteps
87
+ t2 = (i + 1) / num_diffusion_timesteps
88
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
89
+ return np.array(betas)
90
+
91
+
92
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
93
+ """
94
+ Create a beta schedule that discretizes the given alpha_t_bar function,
95
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
96
+
97
+ :param num_diffusion_timesteps: the number of betas to produce.
98
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
99
+ produces the cumulative product of (1-beta) up to that
100
+ part of the diffusion process.
101
+ :param max_beta: the maximum beta to use; use values lower than 1 to
102
+ prevent singularities.
103
+ """
104
+ betas = []
105
+ for i in range(num_diffusion_timesteps):
106
+ t1 = i / num_diffusion_timesteps
107
+ t2 = (i + 1) / num_diffusion_timesteps
108
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
109
+ return np.array(betas)
110
+
111
+
112
+ class ModelMeanType(enum.Enum):
113
+ """
114
+ Which type of output the model predicts.
115
+ """
116
+
117
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
118
+ START_X = enum.auto() # the model predicts x_0
119
+ EPSILON = enum.auto() # the model predicts epsilon
120
+
121
+
122
+ class ModelVarType(enum.Enum):
123
+ """
124
+ What is used as the model's output variance.
125
+
126
+ The LEARNED_RANGE option has been added to allow the model to predict
127
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
128
+ """
129
+
130
+ LEARNED = enum.auto()
131
+ FIXED_SMALL = enum.auto()
132
+ FIXED_LARGE = enum.auto()
133
+ LEARNED_RANGE = enum.auto()
134
+
135
+
136
+ class LossType(enum.Enum):
137
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
138
+ RESCALED_MSE = (
139
+ enum.auto()
140
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
141
+ KL = enum.auto() # use the variational lower-bound
142
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
143
+ E2E_KL = enum.auto()
144
+ E2E_MSE = enum.auto()
145
+ E2E_Simple_MSE = enum.auto()
146
+ E2E_Simple_KL = enum.auto()
147
+
148
+ def is_vb(self):
149
+ return self == LossType.KL or self == LossType.RESCALED_KL
150
+
151
+
152
+ class GaussianDiffusion:
153
+ """
154
+ Utilities for training and sampling diffusion models.
155
+
156
+ Ported directly from here, and then adapted over time to further experimentation.
157
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
158
+
159
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
160
+ starting at T and going to 1.
161
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
162
+ :param model_var_type: a ModelVarType determining how variance is output.
163
+ :param loss_type: a LossType determining the loss function to use.
164
+ :param rescale_timesteps: if True, pass floating point timesteps into the
165
+ model so that they are always scaled like in the
166
+ original paper (0 to 1000).
167
+ """
168
+
169
+ def __init__(
170
+ self,
171
+ *,
172
+ betas,
173
+ model_mean_type,
174
+ model_var_type,
175
+ loss_type,
176
+ rescale_timesteps=False,
177
+ model_arch=None,
178
+ training_mode="emb",
179
+ ):
180
+ self.model_mean_type = model_mean_type
181
+ self.model_var_type = model_var_type
182
+ self.loss_type = loss_type
183
+ self.rescale_timesteps = rescale_timesteps
184
+ self.model_arch = model_arch
185
+
186
+ # Use float64 for accuracy.
187
+ betas = np.array(betas, dtype=np.float64)
188
+ self.betas = betas
189
+ assert len(betas.shape) == 1, "betas must be 1-D"
190
+ assert (betas > 0).all() and (betas <= 1).all()
191
+
192
+ self.num_timesteps = int(betas.shape[0])
193
+
194
+ alphas = 1.0 - betas
195
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
196
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
197
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
198
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
199
+
200
+ # calculations for diffusion q(x_t | x_{t-1}) and others
201
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
202
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
203
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
204
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
205
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
206
+
207
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
208
+ self.posterior_variance = (
209
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
210
+ )
211
+ # log calculation clipped because the posterior variance is 0 at the
212
+ # beginning of the diffusion chain.
213
+ self.posterior_log_variance_clipped = np.log(
214
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
215
+ )
216
+ self.posterior_mean_coef1 = (
217
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
218
+ )
219
+ self.posterior_mean_coef2 = (
220
+ (1.0 - self.alphas_cumprod_prev)
221
+ * np.sqrt(alphas)
222
+ / (1.0 - self.alphas_cumprod)
223
+ )
224
+
225
+ self.training_mode = training_mode
226
+ self.mapping_func = None
227
+ #
228
+ # if training_mode == 'e2e':
229
+ # self.training_losses = self.training_losses_e2e
230
+ # else:
231
+ # self.training_losses = self.training_losses_emb
232
+ self.maxt = -1
233
+
234
+ def training_losses(self, model, *args, **kwargs):
235
+ return self.training_losses_e2e(model, *args, **kwargs)
236
+ # if self.training_mode == "e2e":
237
+ # return self.training_losses_e2e(model, *args, **kwargs)
238
+ # elif self.training_mode == "e2e-simple":
239
+ # return self.training_losses_e2e_simple(model, *args, **kwargs)
240
+ # else:
241
+ # return self.training_losses_emb(model, *args, **kwargs)
242
+
243
+ def calc_bpd_loop(self, model, *args, **kwargs):
244
+ if self.training_mode == "e2e":
245
+ return self.calc_bpd_loop_e2e(model, *args, **kwargs)
246
+ else:
247
+ return self.calc_bpd_loop_emb(model, *args, **kwargs)
248
+
249
+ def q_mean_variance(self, x_start, t):
250
+ """
251
+ Get the distribution q(x_t | x_0).
252
+
253
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
254
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
255
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
256
+ """
257
+ mean = (
258
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
259
+ )
260
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
261
+ log_variance = _extract_into_tensor(
262
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
263
+ )
264
+ return mean, variance, log_variance
265
+
266
+ def q_sample(self, x_start, t, noise=None):
267
+ """
268
+ Diffuse the data for a given number of diffusion steps.
269
+
270
+ In other words, sample from q(x_t | x_0).
271
+
272
+ :param x_start: the initial data batch.
273
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
274
+ :param noise: if specified, the Gaussian noise to use instead of sampling it.
275
+ :return: A noisy version of x_start.
276
+ """
277
+ if noise is None:
278
+ noise = torch.randn_like(x_start)
279
+ assert noise.shape == x_start.shape
280
+ return (
281
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
282
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
283
+ * noise
284
+ )
285
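As a quick sanity check, a minimal sketch of the closed-form forward step this implements, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps; the tensor sizes and constructor arguments here are assumptions for illustration:

import torch

# Hypothetical setup: batch of 4 sequences, length 16, embedding dim 32.
diffusion = GaussianDiffusion(
    betas=betas,  # e.g. the schedule sketched above
    model_mean_type=ModelMeanType.START_X,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.E2E_MSE,
)
x0 = torch.randn(4, 16, 32)
t = torch.randint(0, diffusion.num_timesteps, (4,))
xt = diffusion.q_sample(x0, t)  # same shape as x0; noisier as t grows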
+
286
+ def q_posterior_mean_variance(self, x_start, x_t, t):
287
+ """
288
+ Compute the mean and variance of the diffusion posterior:
289
+
290
+ q(x_{t-1} | x_t, x_0)
291
+
292
+ """
293
+ assert x_start.shape == x_t.shape
294
+ posterior_mean = (
295
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
296
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
297
+ )
298
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
299
+ posterior_log_variance_clipped = _extract_into_tensor(
300
+ self.posterior_log_variance_clipped, t, x_t.shape
301
+ )
302
+ assert (
303
+ posterior_mean.shape[0]
304
+ == posterior_variance.shape[0]
305
+ == posterior_log_variance_clipped.shape[0]
306
+ == x_start.shape[0]
307
+ )
308
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
309
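For reference, the coefficients precomputed in __init__ realize the standard DDPM posterior:

\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t, \qquad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t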
+
310
+ def p_mean_variance(
311
+ self,
312
+ model,
313
+ x,
314
+ t,
315
+ clip_denoised=True,
316
+ denoised_fn=None,
317
+ model_kwargs=None,
318
+ caption=None,
319
+ ):
320
+ """
321
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
322
+ the initial x, x_0.
323
+
324
+ :param model: the model, which takes a signal and a batch of timesteps
325
+ as input.
326
+ :param x: the [N x C x ...] tensor at time t.
327
+ :param t: a 1-D Tensor of timesteps.
328
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
329
+ :param denoised_fn: if not None, a function which applies to the
330
+ x_start prediction before it is used to sample. Applies before
331
+ clip_denoised.
332
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
333
+ pass to the model. This can be used for conditioning.
334
+ :return: a dict with the following keys:
335
+ - 'mean': the model mean output.
336
+ - 'variance': the model variance output.
337
+ - 'log_variance': the log of 'variance'.
338
+ - 'pred_xstart': the prediction for x_0.
339
+ """
340
+ caption_state, caption_mask = caption[0], caption[1]
341
+ if model_kwargs is None:
342
+ model_kwargs = {}
343
+ if self.model_arch == "conv-unet" or self.model_arch == "1d-unet":
344
+ B, C = x.shape[:2]
345
+ else:
346
+ B, C = x.size(0), x.size(-1)
347
+ assert t.shape == (B,)
348
+ # print(x.shape)
349
+ model_output = model(
350
+ x, self._scale_timesteps(t), caption_state, caption_mask, **model_kwargs
351
+ )
352
+
353
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
354
+ if self.model_arch == "conv-unet":
355
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
356
+ model_output, model_var_values = torch.split(model_output, C, dim=1)
357
+ # print('conv-unet')
358
+ elif self.model_arch == "1d-unet":
359
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
360
+ model_output, model_var_values = torch.split(model_output, C, dim=1)
361
+ else:
362
+ assert model_output.shape == (B, x.size(1), C * 2)
363
+ model_output, model_var_values = torch.split(model_output, C, dim=-1)
364
+
365
+ if self.model_var_type == ModelVarType.LEARNED:
366
+ model_log_variance = model_var_values
367
+ model_variance = torch.exp(model_log_variance)
368
+ else:
369
+ min_log = _extract_into_tensor(
370
+ self.posterior_log_variance_clipped, t, x.shape
371
+ )
372
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
373
+ # The model_var_values is [-1, 1] for [min_var, max_var].
374
+ frac = (model_var_values + 1) / 2
375
+ model_log_variance = frac * max_log + (1 - frac) * min_log
376
+ model_variance = torch.exp(model_log_variance)
377
+ else:
378
+ model_variance, model_log_variance = {
379
+ # for fixedlarge, we set the initial (log-)variance like so
380
+ # to get a better decoder log likelihood.
381
+ ModelVarType.FIXED_LARGE: (
382
+ np.append(self.posterior_variance[1], self.betas[1:]),
383
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
384
+ ),
385
+ ModelVarType.FIXED_SMALL: (
386
+ self.posterior_variance,
387
+ self.posterior_log_variance_clipped,
388
+ ),
389
+ }[self.model_var_type]
390
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
391
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
392
+
393
+ def process_xstart(x):
394
+ if denoised_fn is not None:
395
+ # print(denoised_fn)
396
+ x = denoised_fn(x, t)
397
+ if clip_denoised:
398
+ return x.clamp(-1, 1)
399
+ return x
400
+
401
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
402
+ pred_xstart = process_xstart(
403
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
404
+ )
405
+ model_mean = model_output
406
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
407
+ if self.model_mean_type == ModelMeanType.START_X:
408
+ pred_xstart = process_xstart(model_output)
409
+ else:
410
+ pred_xstart = process_xstart(
411
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
412
+ )
413
+ model_mean, _, _ = self.q_posterior_mean_variance(
414
+ x_start=pred_xstart, x_t=x, t=t
415
+ )
416
+ else:
417
+ raise NotImplementedError(self.model_mean_type)
418
+
419
+ assert (
420
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
421
+ )
422
+ return {
423
+ "mean": model_mean,
424
+ "variance": model_variance,
425
+ "log_variance": model_log_variance,
426
+ "pred_xstart": pred_xstart,
427
+ }
428
+
429
+ def _predict_xstart_from_eps(self, x_t, t, eps):
430
+ assert x_t.shape == eps.shape
431
+ return (
432
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
433
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
434
+ )
435
+
436
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
437
+ assert x_t.shape == xprev.shape
438
+ return ( # (xprev - coef2*x_t) / coef1
439
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
440
+ - _extract_into_tensor(
441
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
442
+ )
443
+ * x_t
444
+ )
445
+
446
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
447
+ return (
448
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
449
+ - pred_xstart
450
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
451
+
452
+ def _scale_timesteps(self, t):
453
+ if self.rescale_timesteps:
454
+ return t.float() * (1000.0 / self.num_timesteps)
455
+ return t
456
+
457
+ def p_sample(
458
+ self,
459
+ model,
460
+ x,
461
+ t,
462
+ clip_denoised=True,
463
+ denoised_fn=None,
464
+ model_kwargs=None,
465
+ top_p=None,
466
+ caption=None,
467
+ ):
468
+ """
469
+ Sample x_{t-1} from the model at the given timestep.
470
+
471
+ :param model: the model to sample from.
472
+ :param x: the current tensor x_t at timestep t.
473
+ :param t: the value of t, starting at 0 for the first diffusion step.
474
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
475
+ :param denoised_fn: if not None, a function which applies to the
476
+ x_start prediction before it is used to sample.
477
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
478
+ pass to the model. This can be used for conditioning.
479
+ :return: a dict containing the following keys:
480
+ - 'sample': a random sample from the model.
481
+ - 'pred_xstart': a prediction of x_0.
482
+ """
483
+ out = self.p_mean_variance(
484
+ model,
485
+ x,
486
+ t,
487
+ clip_denoised=clip_denoised,
488
+ denoised_fn=denoised_fn,
489
+ model_kwargs=model_kwargs,
490
+ caption=caption,
491
+ )
492
+ if top_p is not None and top_p > 0:
493
+ # print('top_p sampling')
494
+ noise = torch.randn_like(x)
495
+ replace_mask = torch.abs(noise) > top_p
496
+ while replace_mask.any():
497
+ noise[replace_mask] = torch.randn_like(noise[replace_mask])
498
+ replace_mask = torch.abs(noise) > top_p
499
+ assert (torch.abs(noise) <= top_p).all()
500
+
501
+ else:
502
+ noise = torch.randn_like(x)
503
+ nonzero_mask = (
504
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
505
+ ) # no noise when t == 0
506
+ sample = (
507
+ out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
508
+ )
509
+ return {
510
+ "sample": sample,
511
+ "pred_xstart": out["pred_xstart"],
512
+ "greedy_mean": out["mean"],
513
+ "out": out,
514
+ }
515
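A minimal, hypothetical single reverse step; model, xt, caption_state, and caption_mask are placeholders for objects produced elsewhere:

t = torch.full((4,), 500, dtype=torch.long)  # one timestep index per batch element
out = diffusion.p_sample(
    model, xt, t,
    top_p=0.9,  # optional: truncate the injected noise to |z| <= 0.9
    caption=(caption_state, caption_mask),
)
x_prev, x0_hat = out["sample"], out["pred_xstart"]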
+
516
+ def p_debug_loop(
517
+ self,
518
+ model,
519
+ shape,
520
+ noise=None,
521
+ clip_denoised=True,
522
+ denoised_fn=None,
523
+ model_kwargs=None,
524
+ device=None,
525
+ progress=False,
526
+ ):
527
+ final = None
528
+ for sample in self.p_debug_loop_progressive(
529
+ model,
530
+ shape,
531
+ noise=noise,
532
+ clip_denoised=clip_denoised,
533
+ denoised_fn=denoised_fn,
534
+ model_kwargs=model_kwargs,
535
+ device=device,
536
+ progress=progress,
537
+ ):
538
+ final = sample
539
+ return final["sample"]
540
+
541
+ def p_debug_loop_progressive(
542
+ self,
543
+ model,
544
+ shape,
545
+ noise=None,
546
+ clip_denoised=True,
547
+ denoised_fn=None,
548
+ model_kwargs=None,
549
+ device=None,
550
+ progress=False,
551
+ custom_t_start=100,
552
+ ):
553
+ """
554
+ Generate samples from the model and yield intermediate samples from
555
+ each timestep of diffusion.
556
+
557
+ Arguments are the same as p_sample_loop().
558
+ Returns a generator over dicts, where each dict is the return value of
559
+ p_sample().
560
+ """
561
+ if device is None:
562
+ device = next(model.parameters()).device
563
+ assert isinstance(shape, (tuple, list))
564
+ if noise is not None:
565
+ img = noise
566
+ else:
567
+ img = torch.randn(*shape, device=device)
568
+ indices = list(range(custom_t_start))[::-1]
569
+
570
+ if progress:
571
+ # Lazy import so that we don't depend on tqdm.
572
+ from tqdm.auto import tqdm
573
+
574
+ indices = tqdm(indices)
575
+
576
+ for i in indices:
577
+ t = torch.tensor([i] * shape[0], device=device)
578
+ with torch.no_grad():
579
+ out = self.p_sample(
580
+ model,
581
+ img,
582
+ t,
583
+ clip_denoised=clip_denoised,
584
+ denoised_fn=denoised_fn,
585
+ model_kwargs=model_kwargs,
586
+ )
587
+ yield out
588
+ img = out["sample"]
589
+
590
+ def p_sample_loop(
591
+ self,
592
+ model,
593
+ shape,
594
+ noise=None,
595
+ clip_denoised=True,
596
+ denoised_fn=None,
597
+ model_kwargs=None,
598
+ device=None,
599
+ progress=False,
600
+ top_p=None,
601
+ caption=None,
602
+ ):
603
+ """
604
+ Generate samples from the model.
605
+
606
+ :param model: the model module.
607
+ :param shape: the shape of the samples, (N, C, H, W).
608
+ :param noise: if specified, the noise from the encoder to sample.
609
+ Should be of the same shape as `shape`.
610
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
611
+ :param denoised_fn: if not None, a function which applies to the
612
+ x_start prediction before it is used to sample.
613
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
614
+ pass to the model. This can be used for conditioning.
615
+ :param device: if specified, the device to create the samples on.
616
+ If not specified, use a model parameter's device.
617
+ :param progress: if True, show a tqdm progress bar.
618
+ :return: a non-differentiable batch of samples.
619
+ """
620
+ final = None
621
+ for sample in self.p_sample_loop_progressive(
622
+ model,
623
+ shape,
624
+ noise=noise,
625
+ clip_denoised=clip_denoised,
626
+ denoised_fn=denoised_fn,
627
+ model_kwargs=model_kwargs,
628
+ device=device,
629
+ progress=progress,
630
+ top_p=top_p,
631
+ caption=caption,
632
+ ):
633
+ final = sample
634
+ return final["sample"]
635
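A minimal, hypothetical end-to-end ancestral sampling call; the caption tensors are assumed to come from a frozen text encoder:

samples = diffusion.p_sample_loop(
    model,
    shape=(4, 16, 32),  # (batch, sequence length, embedding dim) here
    progress=True,
    caption=(caption_state, caption_mask),
)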
+
636
+ def p_sample_loop_progressive(
637
+ self,
638
+ model,
639
+ shape,
640
+ noise=None,
641
+ clip_denoised=True,
642
+ denoised_fn=None,
643
+ model_kwargs=None,
644
+ device=None,
645
+ progress=False,
646
+ top_p=None,
647
+ caption=None,
648
+ ):
649
+ """
650
+ Generate samples from the model and yield intermediate samples from
651
+ each timestep of diffusion.
652
+
653
+ Arguments are the same as p_sample_loop().
654
+ Returns a generator over dicts, where each dict is the return value of
655
+ p_sample().
656
+ """
657
+ if device is None:
658
+ device = next(model.parameters()).device
659
+ assert isinstance(shape, (tuple, list))
660
+ if noise is not None:
661
+ img = noise.to(device)
662
+ else:
663
+ img = torch.randn(*shape, device=device)
664
+ indices = list(range(self.num_timesteps))[::-1]
665
+ # print(indices[-10:])
666
+ # indices = indices[:-1]+[1,1,1,1,1,1,1]*60+[0]
667
+ # print(indices[-10:])
668
+ if progress:
669
+ # Lazy import so that we don't depend on tqdm.
670
+ from tqdm.auto import tqdm
671
+
672
+ indices = tqdm(indices)
673
+ if caption is not None:
674
+ print("Text Guiding Generation ......")
675
+ caption = (
676
+ caption[0].to(img.device),
677
+ caption[1].to(img.device),
678
+ ) # (caption_state, caption_mask)
679
+ for i in indices:
680
+ t = torch.tensor([i] * shape[0], device=device)
681
+ with torch.no_grad():
682
+ out = self.p_sample(
683
+ model,
684
+ img,
685
+ t,
686
+ clip_denoised=clip_denoised,
687
+ denoised_fn=denoised_fn,
688
+ model_kwargs=model_kwargs,
689
+ top_p=top_p,
690
+ caption=caption,
691
+ )
692
+ yield out
693
+ img = out["sample"]
694
+
695
+ def p_sample_loop_langevin_progressive(
696
+ self,
697
+ model,
698
+ shape,
699
+ noise=None,
700
+ clip_denoised=True,
701
+ denoised_fn=None,
702
+ model_kwargs=None,
703
+ device=None,
704
+ progress=False,
705
+ langevin_func=None,
706
+ top_p=None,
707
+ ):
708
+ """
709
+ Generate samples from the model and yield intermediate samples from
710
+ each timestep of diffusion.
711
+
712
+ Arguments are the same as p_sample_loop().
713
+ Returns a generator over dicts, where each dict is the return value of
714
+ p_sample().
715
+ """
716
+ if device is None:
717
+ device = next(model.parameters()).device
718
+ assert isinstance(shape, (tuple, list))
719
+ if noise is not None:
720
+ img = noise
721
+ else:
722
+ img = torch.randn(*shape, device=device)
723
+ indices = list(range(self.num_timesteps))[::-1]
724
+
725
+ if progress:
726
+ # Lazy import so that we don't depend on tqdm.
727
+ from tqdm.auto import tqdm
728
+
729
+ indices = tqdm(indices)
730
+
731
+ for i in indices:
732
+ t = torch.tensor([i] * shape[0], device=device)
733
+ with torch.no_grad():
734
+ out = self.p_sample(
735
+ model,
736
+ img,
737
+ t,
738
+ clip_denoised=clip_denoised,
739
+ denoised_fn=denoised_fn,
740
+ model_kwargs=model_kwargs,
741
+ top_p=top_p,
742
+ )
743
+ if langevin_func is not None:
744
+ out["t"] = t
745
+ out["img"] = img
746
+ out = langevin_func(out)
747
+ yield out
748
+ img = out["sample"]
749
+
750
+ def p_sample_loop_progressive_infill(
751
+ self,
752
+ model,
753
+ shape,
754
+ partial_enc,
755
+ partial_mask,
756
+ noise=None,
757
+ clip_denoised=True,
758
+ denoised_fn=None,
759
+ model_kwargs=None,
760
+ device=None,
761
+ progress=False,
762
+ greedy=False,
763
+ ):
764
+ """
765
+ Generate samples from the model and yield intermediate samples from
766
+ each timestep of diffusion.
767
+
768
+ Arguments are the same as p_sample_loop().
769
+ Returns a generator over dicts, where each dict is the return value of
770
+ p_sample().
771
+ """
772
+ if device is None:
773
+ device = next(model.parameters()).device
774
+ assert isinstance(shape, (tuple, list))
775
+ if noise is not None:
776
+ img = noise
777
+ # img = img[partial_mask] + partial_enc_with_noise[~partial_mask]
778
+ else:
779
+ t_batch = torch.tensor([self.num_timesteps - 1] * shape[0], device=device)
780
+ partial_enc_with_noise = self.q_sample(partial_enc, t_batch)
781
+ img = torch.randn(*shape, device=device)
782
+ # print(img.shape, partial_enc_with_noise.shape, partial_mask.shape)
783
+ # img = img[partial_mask] + partial_enc_with_noise[~partial_mask]
784
+ img[~partial_mask] = partial_enc_with_noise[~partial_mask]
785
+ indices = list(range(self.num_timesteps))[::-1]
786
+
787
+ if progress:
788
+ # Lazy import so that we don't depend on tqdm.
789
+ from tqdm.auto import tqdm
790
+
791
+ indices = tqdm(indices)
792
+
793
+ for i in indices:
794
+ t = torch.tensor([i] * shape[0], device=device)
795
+ with torch.no_grad():
796
+ out = self.p_sample(
797
+ model,
798
+ img,
799
+ t,
800
+ clip_denoised=clip_denoised,
801
+ denoised_fn=denoised_fn,
802
+ model_kwargs=model_kwargs,
803
+ )
804
+ if i > 0:
805
+ partial_enc_with_noise = self.q_sample(partial_enc, t - 1)
806
+ else:
807
+ partial_enc_with_noise = partial_enc
808
+ if greedy:
809
+ img = out["greedy_mean"]
810
+ img[~partial_mask] = partial_enc[~partial_mask]
811
+ out["sample"] = img
812
+ else:
813
+ img = out["sample"]
814
+ img[~partial_mask] = partial_enc[~partial_mask]
815
+ # img[~partial_mask] = partial_enc_with_noise[~partial_mask]
816
+ out["sample"] = img
817
+ yield out
818
+
819
+ def p_sample_loop_progressive_merge(
820
+ self,
821
+ model,
822
+ shape,
823
+ partial_enc,
824
+ partial_mask,
825
+ noise=None,
826
+ clip_denoised=True,
827
+ denoised_fn=None,
828
+ model_kwargs=None,
829
+ device=None,
830
+ progress=False,
831
+ greedy=False,
832
+ ):
833
+ """
834
+ Generate samples from the model and yield intermediate samples from
835
+ each timestep of diffusion.
836
+
837
+ Arguments are the same as p_sample_loop().
838
+ Returns a generator over dicts, where each dict is the return value of
839
+ p_sample().
840
+ """
841
+ if device is None:
842
+ device = next(model.parameters()).device
843
+ assert isinstance(shape, (tuple, list))
844
+ if noise is not None:
845
+ img = noise
846
+ # img = img[partial_mask] + partial_enc_with_noise[~partial_mask]
847
+ else:
848
+ t_batch = torch.tensor([self.num_timesteps - 1] * shape[0], device=device)
849
+ partial_enc_with_noise = self.q_sample(partial_enc, t_batch)
850
+ img = torch.randn(*shape, device=device)
851
+ # print(img.shape, partial_enc_with_noise.shape, partial_mask.shape)
852
+ # img = img[partial_mask] + partial_enc_with_noise[~partial_mask]
853
+ img[~partial_mask] = partial_enc_with_noise[~partial_mask]
854
+ indices = list(range(self.num_timesteps))[::-1]
855
+
856
+ if progress:
857
+ # Lazy import so that we don't depend on tqdm.
858
+ from tqdm.auto import tqdm
859
+
860
+ indices = tqdm(indices)
861
+
862
+ for i in indices:
863
+ t = torch.tensor([i] * shape[0], device=device)
864
+ with torch.no_grad():
865
+ out = self.p_sample(
866
+ model,
867
+ img,
868
+ t,
869
+ clip_denoised=clip_denoised,
870
+ denoised_fn=denoised_fn,
871
+ model_kwargs=model_kwargs,
872
+ )
873
+ if i > 0:
874
+ partial_enc_with_noise = self.q_sample(partial_enc, t - 1)
875
+ else:
876
+ partial_enc_with_noise = partial_enc
877
+ if greedy:
878
+ img = out["greedy_mean"]
879
+ img[~partial_mask] = partial_enc[~partial_mask]
880
+ out["sample"] = img
881
+ else:
882
+ img = out["sample"]
883
+ img[~partial_mask] = partial_enc[~partial_mask]
884
+ # img[~partial_mask] = partial_enc_with_noise[~partial_mask]
885
+ out["sample"] = img
886
+ yield out
887
+
888
+ def ddim_sample(
889
+ self,
890
+ model,
891
+ x,
892
+ t,
893
+ clip_denoised=True,
894
+ denoised_fn=None,
895
+ model_kwargs=None,
896
+ eta=0.0,
897
+ langevin_fn=None,
898
+ caption=None,
899
+ ):
900
+ """
901
+ Sample x_{t-1} from the model using DDIM.
902
+
903
+ Same usage as p_sample().
904
+ """
905
+ out = self.p_mean_variance(
906
+ model,
907
+ x,
908
+ t,
909
+ clip_denoised=clip_denoised,
910
+ denoised_fn=denoised_fn,
911
+ model_kwargs=model_kwargs,
912
+ caption=caption,
913
+ )
914
+ # Usually our model outputs epsilon, but we re-derive it
915
+ # in case we used x_start or x_prev prediction.
916
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
917
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
918
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
919
+ sigma = (
920
+ eta
921
+ * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
922
+ * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
923
+ )
924
+ # Equation 12.
925
+ noise = torch.randn_like(x)
926
+ mean_pred = (
927
+ out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
928
+ + torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps
929
+ )
930
+ nonzero_mask = (
931
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
932
+ ) # no noise when t == 0
933
+ # print(sigma.mean())
934
+ sample = mean_pred + nonzero_mask * sigma * noise
935
+ if langevin_fn:
936
+ # print(t.shape)
937
+ sample = langevin_fn(
938
+ sample, mean_pred, sigma, self.alphas_cumprod_prev[t[0]], t, x
939
+ )
940
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
941
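For reference, the sigma above and the resulting update are Equation 12 of the DDIM paper:

\sigma_t = \eta \sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}} \sqrt{1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}, \qquad x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\hat{\epsilon} + \sigma_t z, \quad z \sim \mathcal{N}(0, I)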
+
942
+ def ddim_reverse_sample(
943
+ self,
944
+ model,
945
+ x,
946
+ t,
947
+ clip_denoised=True,
948
+ denoised_fn=None,
949
+ model_kwargs=None,
950
+ eta=0.0,
951
+ ):
952
+ """
953
+ Sample x_{t+1} from the model using DDIM reverse ODE.
954
+ """
955
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
956
+ out = self.p_mean_variance(
957
+ model,
958
+ x,
959
+ t,
960
+ clip_denoised=clip_denoised,
961
+ denoised_fn=denoised_fn,
962
+ model_kwargs=model_kwargs,
963
+ )
964
+ # Usually our model outputs epsilon, but we re-derive it
965
+ # in case we used x_start or x_prev prediction.
966
+ eps = (
967
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
968
+ - out["pred_xstart"]
969
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
970
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
971
+
972
+ # Equation 12. reversed
973
+ mean_pred = (
974
+ out["pred_xstart"] * torch.sqrt(alpha_bar_next)
975
+ + torch.sqrt(1 - alpha_bar_next) * eps
976
+ )
977
+
978
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
979
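In math, with \hat{x}_0 the model's prediction and \hat{\epsilon} re-derived from it, this deterministic encoding step is

x_{t+1} = \sqrt{\bar{\alpha}_{t+1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t+1}}\,\hat{\epsilon}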
+
980
+ def ddim_sample_loop(
981
+ self,
982
+ model,
983
+ shape,
984
+ noise=None,
985
+ clip_denoised=True,
986
+ denoised_fn=None,
987
+ model_kwargs=None,
988
+ device=None,
989
+ progress=False,
990
+ eta=0.0,
991
+ top_p=-1.0,
992
+ langevin_fn=None,
993
+ caption=None,
994
+ ):
995
+ """
996
+ Generate samples from the model using DDIM.
997
+
998
+ Same usage as p_sample_loop().
999
+ """
1000
+ final = None
1001
+ for sample in self.ddim_sample_loop_progressive(
1002
+ model,
1003
+ shape,
1004
+ noise=noise,
1005
+ clip_denoised=clip_denoised,
1006
+ denoised_fn=denoised_fn,
1007
+ model_kwargs=model_kwargs,
1008
+ device=device,
1009
+ progress=progress,
1010
+ eta=eta,
1011
+ langevin_fn=langevin_fn,
1012
+ caption=caption,
1013
+ ):
1014
+ final = sample
1015
+ return final["sample"]
1016
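A minimal, hypothetical deterministic DDIM run (eta = 0 keeps the sampler on the deterministic path):

samples = diffusion.ddim_sample_loop(
    model,
    shape=(4, 16, 32),
    eta=0.0,
    progress=True,
    caption=(caption_state, caption_mask),
)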
+
1017
+ def ddim_sample_loop_progressive(
1018
+ self,
1019
+ model,
1020
+ shape,
1021
+ noise=None,
1022
+ clip_denoised=True,
1023
+ denoised_fn=None,
1024
+ model_kwargs=None,
1025
+ device=None,
1026
+ progress=False,
1027
+ eta=0.0,
1028
+ langevin_fn=None,
1029
+ caption=None,
1030
+ ):
1031
+ """
1032
+ Use DDIM to sample from the model and yield intermediate samples from
1033
+ each timestep of DDIM.
1034
+
1035
+ Same usage as p_sample_loop_progressive().
1036
+ """
1037
+ if device is None:
1038
+ device = next(model.parameters()).device
1039
+ assert isinstance(shape, (tuple, list))
1040
+ if noise is not None:
1041
+ img = noise
1042
+ else:
1043
+ img = torch.randn(*shape, device=device)
1044
+ indices = list(range(self.num_timesteps))[::-1]
1045
+ if caption is not None:
1046
+ print("Text Guiding Generation ......")
1047
+ caption = (
1048
+ caption[0].to(img.device),
1049
+ caption[1].to(img.device),
1050
+ ) # (caption_state, caption_mask)
1051
+ if progress:
1052
+ # Lazy import so that we don't depend on tqdm.
1053
+ from tqdm.auto import tqdm
1054
+
1055
+ indices = tqdm(indices)
1056
+
1057
+ for i in indices:
1058
+ t = torch.tensor([i] * shape[0], device=device)
1059
+ with torch.no_grad():
1060
+ out = self.ddim_sample(
1061
+ model,
1062
+ img,
1063
+ t,
1064
+ clip_denoised=clip_denoised,
1065
+ denoised_fn=denoised_fn,
1066
+ model_kwargs=model_kwargs,
1067
+ eta=eta,
1068
+ langevin_fn=langevin_fn,
1069
+ caption=caption,
1070
+ )
1071
+ yield out
1072
+ img = out["sample"]
1073
+
1074
+ def _vb_terms_bpd(
1075
+ self,
1076
+ model,
1077
+ x_start,
1078
+ x_t,
1079
+ t,
1080
+ clip_denoised=True,
1081
+ model_kwargs=None,
1082
+ noise=None,
1083
+ denoised_fn=None,
1084
+ ):
1085
+ """
1086
+ Get a term for the variational lower-bound.
1087
+
1088
+ The resulting units are bits (rather than nats, as one might expect).
1089
+ This allows for comparison to other papers.
1090
+
1091
+ :return: a dict with the following keys:
1092
+ - 'output': a shape [N] tensor of NLLs or KLs.
1093
+ - 'pred_xstart': the x_0 predictions.
1094
+ """
1095
+ # lambda *args, r=frozen_out: r,
1096
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
1097
+ x_start=x_start, x_t=x_t, t=t
1098
+ )
1099
+ if model_kwargs is not None and "input_ids" in model_kwargs:
1100
+ input_ids = model_kwargs.pop("input_ids")
1101
+ mapping_func = model_kwargs.pop("mapping_func", self.mapping_func)
1102
+ else:
1103
+ input_ids = None
1104
+ # noise=None
1105
+ out = self.p_mean_variance(
1106
+ model,
1107
+ x_t,
1108
+ t,
1109
+ clip_denoised=clip_denoised,
1110
+ model_kwargs=model_kwargs,
1111
+ denoised_fn=denoised_fn,
1112
+ )
1113
+ kl = normal_kl(
1114
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
1115
+ )
1116
+ kl = mean_flat(kl) / np.log(2.0)
1117
+
1118
+ if input_ids is not None:
1119
+ # print('input_ids is not None')
1120
+ # from torch.distributions import Normal
1121
+ # normal_dist = Normal(out["mean"], (0.5 * out["log_variance"]).exp())
1122
+ # decoder_nll = -normal_dist.log_prob(x_start)
1123
+ assert mapping_func is not None
1124
+ if mapping_func is not None and torch.any(t == 0):
1125
+
1126
+ decoder_nll = mapping_func(out["mean"], input_ids) / out["mean"].size(
1127
+ -1
1128
+ )
1129
+ else:
1130
+ decoder_nll = torch.zeros_like(x_start)
1131
+ model_kwargs["input_ids"] = input_ids
1132
+ model_kwargs["mapping_func"] = mapping_func
1133
+
1134
+ # target = {
1135
+ # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
1136
+ # x_start=x_start, x_t=x_t, t=t
1137
+ # )[0],
1138
+ # ModelMeanType.START_X: x_start,
1139
+ # ModelMeanType.EPSILON: noise,
1140
+ # }[self.model_mean_type]
1141
+ # # print(out['mean'].shape, x_start.shape, self.model_mean_type, noise)
1142
+ # assert out["mean"].shape == target.shape == x_start.shape
1143
+ # decoder_nll = (target - out["mean"]) ** 2
1144
+ else:
1145
+ decoder_nll = -discretized_gaussian_log_likelihood(
1146
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
1147
+ )
1148
+ assert decoder_nll.shape == x_start.shape
1149
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
1150
+
1151
+ # At the first timestep return the decoder NLL,
1152
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
1153
+ output = torch.where((t == 0), decoder_nll, kl)
1154
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
1155
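A note on the unit conversion used throughout: a KL (or NLL) measured in nats becomes bits after dividing by ln 2, i.e. \text{bits} = \text{nats} / \ln 2, which is what mean_flat(...) / np.log(2.0) computes per dimension.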
+
1156
+ def _vb_terms_bpd_e2e(
1157
+ self,
1158
+ model,
1159
+ x_start,
1160
+ x_t,
1161
+ t,
1162
+ input_ids,
1163
+ get_logits,
1164
+ x_start_mean,
1165
+ x_start_log_var,
1166
+ clip_denoised=True,
1167
+ model_kwargs=None,
1168
+ noise=None,
1169
+ denoised_fn=None,
1170
+ ):
1171
+ """
1172
+ Get a term for the variational lower-bound.
1173
+
1174
+ The resulting units are bits (rather than nats, as one might expect).
1175
+ This allows for comparison to other papers.
1176
+
1177
+ :return: a dict with the following keys:
1178
+ - 'output': a shape [N] tensor of NLLs or KLs.
1179
+ - 'pred_xstart': the x_0 predictions.
1180
+ """
1181
+ # lambda *args, r=frozen_out: r,
1182
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
1183
+ x_start=x_start, x_t=x_t, t=t
1184
+ )
1185
+ assert input_ids is not None
1186
+ mapping_func = model_kwargs.pop("mapping_func", self.mapping_func)
1187
+ # assert 'input_ids' in model_kwargs
1188
+ # input_ids = model_kwargs.pop('input_ids')
1189
+
1190
+ out = self.p_mean_variance(
1191
+ model,
1192
+ x_t,
1193
+ t,
1194
+ clip_denoised=clip_denoised,
1195
+ model_kwargs=model_kwargs,
1196
+ denoised_fn=denoised_fn,
1197
+ )
1198
+ # print(true_log_variance_clipped[0], out["log_variance"][0], 'line1259')
1199
+ kl = normal_kl(
1200
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
1201
+ )
1202
+ kl = mean_flat(kl) / np.log(2.0)
1203
+
1204
+ decoder_nll = self.token_discrete_loss(x_start, get_logits, input_ids) # t=-1
1205
+
1206
+ decoder_nll = decoder_nll / out["mean"].size(-1)
1207
+ decoder_nll = decoder_nll / np.log(2.0)
1208
+
1209
+ mask_1 = t == 0
1210
+ if mask_1.any():
1211
+ kl_T = normal_kl(
1212
+ x_start_mean, x_start_log_var, out["mean"], out["log_variance"]
1213
+ )
1214
+ kl_T = mean_flat(kl_T) / np.log(2.0)
1215
+ kl = torch.where(mask_1, kl_T, kl)
1216
+
1217
+ out_mean, out_variance, out_log_variance_clipped = self.q_mean_variance(
1218
+ x_start, torch.LongTensor([self.num_timesteps - 1]).to(x_start.device)
1219
+ )
1220
+ kl_T = normal_kl(out_mean, out_log_variance_clipped, 0, 0)
1221
+ kl_T = mean_flat(kl_T) / np.log(2.0)
1222
+
1223
+ # print(decoder_nll, )
1224
+ # print()
1225
+ # At the first timestep return the decoder NLL,
1226
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
1227
+ # output =torch.where((t == 0), decoder_nll, kl)
1228
+ output = kl + decoder_nll + kl_T
1229
+ return {
1230
+ "output": output,
1231
+ "pred_xstart": out["pred_xstart"],
1232
+ "kl": kl,
1233
+ "decoder_nll": decoder_nll,
1234
+ "kl_T": kl_T,
1235
+ }
1236
+
1237
+ def get_x_start(self, x_start_mean, std):
1238
+ """
1239
+ Sample x_0 ~ q(x_0 | emb) by adding scaled Gaussian noise to the embedding mean.
1240
+ :param x_start_mean: the token-embedding mean of x_0.
1241
+ :return: x_start_mean + std * noise, where noise ~ N(0, I).
1242
+ """
1243
+ noise = torch.randn_like(x_start_mean)
1244
+ # print(std.shape, noise.shape, x_start_mean.shape)
1245
+ assert noise.shape == x_start_mean.shape
1246
+ # print(x_start_mean.device, noise.device)
1247
+ return x_start_mean + std * noise
1248
+
1249
+ def token_discrete_loss(self, x_t, get_logits, input_ids):
1250
+ if self.model_arch == "conv-unet" or self.model_arch == "1d-unet":
1251
+ reshaped_x_t = x_t.view(x_t.size(0), x_t.size(1), -1).permute(0, 2, 1)
1252
+ else:
1253
+ # print(x_t.shape)
1254
+ reshaped_x_t = x_t
1255
+ # logits = get_logits(reshaped_x_t) # bsz, seqlen, vocab
1256
+
1257
+ logits = get_logits(reshaped_x_t)
1258
+
1259
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
1260
+ decoder_nll = loss_fct(
1261
+ logits.view(-1, logits.size(-1)), input_ids.view(-1)
1262
+ ).view(input_ids.shape)
1263
+ decoder_nll = decoder_nll.mean(dim=-1)
1264
+ return decoder_nll
1265
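For orientation, a minimal sketch of how the same logits head is used at inference time to round continuous states back to tokens; the greedy argmax is an assumption, not the only decoding strategy:

# Hypothetical: decode denoised embeddings x0_hat back to discrete token ids.
logits = model.model.get_logits(x0_hat)  # (batch, seqlen, vocab)
token_ids = logits.argmax(dim=-1)        # (batch, seqlen), greedy rounding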
+
1266
+ def x0_helper(self, model_output, x, t):
1267
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
1268
+ pred_xstart = self._predict_xstart_from_xprev(
1269
+ x_t=x, t=t, xprev=model_output
1270
+ )
1271
+ pred_prev = model_output
1272
+
1273
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
1274
+ if self.model_mean_type == ModelMeanType.START_X:
1275
+ pred_xstart = model_output
1276
+ else:
1277
+ pred_xstart = self._predict_xstart_from_eps(
1278
+ x_t=x, t=t, eps=model_output
1279
+ )
1280
+ pred_prev, _, _ = self.q_posterior_mean_variance(
1281
+ x_start=pred_xstart, x_t=x, t=t
1282
+ )
1283
+
1284
+ else:
1285
+ raise NotImplementedError(self.model_mean_type)
1286
+ return {"pred_xprev": pred_prev, "pred_xstart": pred_xstart}
1287
+
1288
+ def training_losses_e2e(self, model, micro, t, noise=None):
1289
+ """
1290
+ The function `training_losses_e2e` calculates various loss terms for an end-to-end training
1291
+ process in a machine learning model.
1292
+
1293
+ :param model: The `model` parameter in the `training_losses_e2e` function seems to be an
1294
+ instance of a model used for training. It is likely a neural network model that is being trained
1295
+ for a specific task, such as sequence generation or prediction. The model is used within the
1296
+ function to make predictions
1297
+ :param micro: The `micro` parameter in the `training_losses_e2e` function seems to be a tuple
1298
+ containing the following elements:
1299
+ :param t: The `t` parameter in the `training_losses_e2e` function seems to represent the time
1300
+ step or timestep index. It is used to determine certain conditions within the function, such as
1301
+ comparing it to a threshold value of 400 and scaling timesteps. The function performs various
1302
+ calculations and computations based
1303
+ :param noise: The `noise` parameter in the `training_losses_e2e` function is used to pass a
1304
+ tensor representing random noise. If the `noise` parameter is not provided when calling the
1305
+ function, it generates random noise using `torch.randn_like(mix_start)`. This noise is then used
1306
+ in the
1307
+ :return: The function `training_losses_e2e` returns a dictionary `terms` containing different
1308
+ loss terms based on the specified loss type. The specific terms included in the dictionary
1309
+ depend on the conditions and calculations performed within the function for the given loss type.
1310
+ The function calculates and populates the `terms` dictionary with relevant loss values such as
1311
+ mean squared error (mse), variational bound (vb), decoder negative
1312
+ """
1313
+ selfies_ids = micro[0]
1314
+ caption_state = micro[1]
1315
+ caption_mask = micro[2]
1316
+ corrupted_selfies_ids = micro[3]
1317
+ assert corrupted_selfies_ids.shape == selfies_ids.shape
1318
+
1319
+ #########################################
1320
+ mix_ids = torch.where(
1321
+ t.reshape(-1, 1) < 400, corrupted_selfies_ids, selfies_ids
1322
+ )
1323
+ if t.max() > self.maxt:
1324
+ self.maxt = t.max()
1325
+ # print("Recieving max t:{}".format(self.maxt))
1326
+ ##########################################
1327
+ # print(f"Model dir: {dir(model)}")
1328
+ try:
1329
+ x_start_mean = model.model.get_embeds(selfies_ids)
1330
+ mix_start_mean = model.model.get_embeds(mix_ids)
1331
+ except AttributeError:  # the inner model is wrapped (e.g. DataParallel/DDP)
1332
+ x_start_mean = model.model.module.get_embeds(selfies_ids)
1333
+ mix_start_mean = model.model.module.get_embeds(mix_ids)
1334
+
1335
+ std = _extract_into_tensor(
1336
+ self.sqrt_one_minus_alphas_cumprod,
1337
+ torch.tensor([0]).to(x_start_mean.device),
1338
+ x_start_mean.shape,
1339
+ )
1340
+
1341
+ x_start = self.get_x_start(x_start_mean, std)
1342
+ mix_start = self.get_x_start(mix_start_mean, std)
1343
+
1344
+ if noise is None:
1345
+ noise = torch.randn_like(mix_start)
1346
+ x_t = self.q_sample(mix_start, t, noise=noise) # reparametrization trick.
1347
+ try:
1348
+ get_logits = model.model.get_logits
1349
+ except AttributeError:  # the inner model is wrapped (e.g. DataParallel/DDP)
1350
+ get_logits = model.model.module.get_logits
1351
+
1352
+ terms = {}
1353
+
1354
+ if self.loss_type == LossType.E2E_KL:
1355
+ pass
1356
+
1357
+ elif (
1358
+ self.loss_type == LossType.E2E_MSE
1359
+ or self.loss_type == LossType.E2E_RESCALED_MSE
1360
+ ):
1361
+ model_output = model(
1362
+ x_t, self._scale_timesteps(t), caption_state, caption_mask
1363
+ )
1364
+
1365
+ if self.model_var_type in [
1366
+ ModelVarType.LEARNED,
1367
+ ModelVarType.LEARNED_RANGE,
1368
+ ]:
1369
+ pass
1370
+
1371
+ target = {
1372
+ # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
1373
+ # x_start=x_start, x_t=x_t, t=t
1374
+ # )[0],
1375
+ ModelMeanType.START_X: x_start,
1376
+ ModelMeanType.EPSILON: noise,
1377
+ }[
1378
+ self.model_mean_type
1379
+ ] # this is exactly x_start
1380
+ # print(model_output.shape ,target.shape , x_start.shape)
1381
+
1382
+ assert model_output.shape == target.shape == x_start.shape
1383
+ terms["mse"] = mean_flat((target - model_output) ** 2)
1384
+ # print( terms["mse"])
1385
+ model_out_x_start = self.x0_helper(model_output, x_t, t)[
1386
+ "pred_xstart"
1387
+ ] # this is exactly model_output
1388
+ t0_mask = t == 0
1389
+ t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
1390
+ # print(terms["mse"].shape, )
1391
+ terms["mse"] = torch.where(t0_mask, t0_loss, terms["mse"])
1392
+
1393
+ # tT_mask = (t == self.num_timesteps - 1)
1394
+ out_mean, _, _ = self.q_mean_variance(
1395
+ x_start, torch.LongTensor([self.num_timesteps - 1]).to(x_start.device)
1396
+ )
1397
+ tT_loss = mean_flat(out_mean**2)
1398
+
1399
+ decoder_nll = self.token_discrete_loss(x_start, get_logits, selfies_ids)
1400
+
1401
+ if "vb" in terms:
1402
+ terms["loss"] = terms["mse"] + terms["vb"]
1403
+ else:
1404
+ terms["loss"] = terms["mse"] + (decoder_nll + tT_loss)
1405
+ else:
1406
+ raise NotImplementedError(self.loss_type)
1407
+
1408
+ return terms
1409
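Schematically, the loss assembled above (for the MSE branch) is

\mathcal{L} = \underbrace{\mathbb{E}_t \lVert \hat{x}_0 - x_0 \rVert^2}_{\texttt{mse},\ \text{with a special } t=0 \text{ term}} + \underbrace{-\log p_\theta(w \mid x_0)}_{\texttt{decoder\_nll}} + \underbrace{\lVert \mu(x_0, T) \rVert^2}_{\texttt{tT\_loss}}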
+
1410
+ def _prior_bpd(self, x_start):
1411
+ """
1412
+ Get the prior KL term for the variational lower-bound, measured in
1413
+ bits-per-dim.
1414
+
1415
+ This term can't be optimized, as it only depends on the encoder.
1416
+
1417
+ :param x_start: the [N x C x ...] tensor of inputs.
1418
+ :return: a batch of [N] KL values (in bits), one per batch element.
1419
+ """
1420
+ batch_size = x_start.shape[0]
1421
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1422
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1423
+ kl_prior = normal_kl(
1424
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1425
+ )
1426
+ return mean_flat(kl_prior) / np.log(2.0)
1427
+
1428
+ def calc_bpd_loop_e2e(
1429
+ self, model, x_start, clip_denoised=True, model_kwargs=None, denoised_fn=None
1430
+ ):
1431
+ device = x_start.device
1432
+ batch_size = x_start.shape[0]
1433
+
1434
+ input_ids = model_kwargs.pop("input_ids").to(device)
1435
+ x_start_mean = model.get_embeds(input_ids)
1436
+ if self.model_arch == "conv-unet":
1437
+ seqlen = int(np.sqrt(input_ids.size(1)))
1438
+ x_start_mean = x_start_mean.view(
1439
+ x_start_mean.size(0), seqlen, seqlen, x_start_mean.size(-1)
1440
+ ).permute(0, 3, 1, 2)
1441
+ elif self.model_arch == "1d-unet":
1442
+ x_start_mean = x_start_mean.permute(0, 2, 1)
1443
+ std = _extract_into_tensor(
1444
+ self.sqrt_one_minus_alphas_cumprod,
1445
+ torch.tensor([0]).to(x_start_mean.device),
1446
+ x_start_mean.shape,
1447
+ )
1448
+ x_start_log_var = 2 * torch.log(std)
1449
+ x_start = self.get_x_start(x_start_mean, std)
1450
+ get_logits = model.get_logits
1451
+
1452
+ vb = []
1453
+ xstart_mse = []
1454
+ mse = []
1455
+ for t in list(range(self.num_timesteps))[::-1]:
1456
+ t_batch = torch.tensor([t] * batch_size, device=device)
1457
+ noise = torch.randn_like(x_start)
1458
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1459
+ with torch.no_grad():
1460
+ out = self._vb_terms_bpd_e2e(
1461
+ model,
1462
+ x_start=x_start,
1463
+ x_t=x_t,
1464
+ t=t_batch,
1465
+ input_ids=input_ids,
1466
+ get_logits=get_logits,
1467
+ x_start_mean=x_start_mean,
1468
+ x_start_log_var=x_start_log_var,
1469
+ clip_denoised=clip_denoised,
1470
+ model_kwargs=model_kwargs,
1471
+ noise=noise,
1472
+ denoised_fn=denoised_fn,
1473
+ )
1474
+ if t == self.num_timesteps - 1:
1475
+ assert len(vb) == 0
1476
+ vb.append(out["kl_T"])
1477
+ vb.append(out["kl"])
1478
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1479
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1480
+ mse.append(mean_flat((eps - noise) ** 2))
1481
+ vb.append(out["decoder_nll"])
1482
+
1483
+ vb = torch.stack(vb, dim=1)
1484
+ xstart_mse = torch.stack(xstart_mse, dim=1)
1485
+ mse = torch.stack(mse, dim=1)
1486
+
1487
+ # prior_bpd = self._prior_bpd(x_start)
1488
+ prior_bpd = out["kl_T"]
1489
+ total_bpd = vb.sum(dim=1)
1490
+ return {
1491
+ "total_bpd": total_bpd,
1492
+ "prior_bpd": prior_bpd,
1493
+ "vb": vb,
1494
+ "xstart_mse": xstart_mse,
1495
+ "mse": mse,
1496
+ }
1497
+
1498
+ def calc_bpd_loop_emb(
1499
+ self, model, x_start, clip_denoised=True, model_kwargs=None, denoised_fn=None
1500
+ ):
1501
+ """
1502
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1503
+ as well as other related quantities.
1504
+
1505
+ :param model: the model to evaluate loss on.
1506
+ :param x_start: the [N x C x ...] tensor of inputs.
1507
+ :param clip_denoised: if True, clip denoised samples.
1508
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1509
+ pass to the model. This can be used for conditioning.
1510
+
1511
+ :return: a dict containing the following keys:
1512
+ - total_bpd: the total variational lower-bound, per batch element.
1513
+ - prior_bpd: the prior term in the lower-bound.
1514
+ - vb: an [N x T] tensor of terms in the lower-bound.
1515
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1516
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1517
+ """
1518
+ device = x_start.device
1519
+ batch_size = x_start.shape[0]
1520
+
1521
+ vb = []
1522
+ xstart_mse = []
1523
+ mse = []
1524
+ for t in list(range(self.num_timesteps))[::-1]:
1525
+ t_batch = torch.tensor([t] * batch_size, device=device)
1526
+ noise = torch.randn_like(x_start)
1527
+ # print(t)
1528
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1529
+ # Calculate VLB term at the current timestep
1530
+ with torch.no_grad():
1531
+ out = self._vb_terms_bpd(
1532
+ model,
1533
+ x_start=x_start,
1534
+ x_t=x_t,
1535
+ t=t_batch,
1536
+ clip_denoised=clip_denoised,
1537
+ model_kwargs=model_kwargs,
1538
+ noise=noise,
1539
+ denoised_fn=denoised_fn,
1540
+ )
1541
+ vb.append(out["output"])
1542
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1543
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1544
+
1545
+ #
1546
+ # ## DEBUG
1547
+ # def is_very_close(a, b):
1548
+ # return (((a - b) ** 2).mean())
1549
+ # x_start_cycle = self._predict_xstart_from_eps(x_t=x_t, t=t_batch, eps=noise)
1550
+ # gold_eps_cycle = self._predict_eps_from_xstart(x_t, t_batch, x_start_cycle)
1551
+ # print(((gold_eps_cycle-noise)**2).mean())
1552
+
1553
+ # print(is_very_close(out2['pred_xstart'],out["pred_xstart"]), 'first isclose --> check p_mean')
1554
+ # model.eval()
1555
+ # with torch.no_grad():
1556
+ # direct_pred_eps = model(x_t, self._scale_timesteps(t_batch), **model_kwargs)
1557
+ # print(((direct_pred_eps - noise) ** 2).mean(), 'ans1', self.rescale_timesteps)
1558
+
1559
+ # x_start_cycle_pred = self._predict_xstart_from_eps(x_t=x_t, t=t_batch, eps=direct_pred_eps)
1560
+ # model_kwargs['debug_x_t'] = x_t
1561
+ # model_kwargs['debug_t_batch'] = t_batch
1562
+ # model_kwargs['debug_direct_pred_eps'] = direct_pred_eps
1563
+ # model_kwargs['debug_x_start_cycle_pred'] = x_start_cycle_pred
1564
+
1565
+ # out2 = self.p_mean_variance(
1566
+ # model, x_t, t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
1567
+ # )
1568
+ # # print(((out["pred_xstart"] - x_start_cycle_pred) ** 2).mean(), 'if not align issue with vb_terms')
1569
+ # print(is_very_close(out2['pred_xstart'], x_start_cycle_pred), '2nd isclose --> check our flattened')
1570
+ # gold_eps_cycle_pred = self._predict_eps_from_xstart(x_t, t_batch, x_start_cycle_pred)
1571
+
1572
+ # print(((eps - noise) ** 2).mean(), 'ans2', self._scale_timesteps)
1573
+ # print()
1574
+ # print(((gold_eps_cycle_pred - direct_pred_eps) ** 2).mean(), 'should be same, exactly same computation..')
1575
+ ## DEBUG
1576
+ mse.append(mean_flat((eps - noise) ** 2))
1577
+
1578
+ vb = torch.stack(vb, dim=1)
1579
+ xstart_mse = torch.stack(xstart_mse, dim=1)
1580
+ mse = torch.stack(mse, dim=1)
1581
+
1582
+ prior_bpd = self._prior_bpd(x_start)
1583
+ total_bpd = vb.sum(dim=1) + prior_bpd
1584
+ return {
1585
+ "total_bpd": total_bpd,
1586
+ "prior_bpd": prior_bpd,
1587
+ "vb": vb,
1588
+ "xstart_mse": xstart_mse,
1589
+ "mse": mse,
1590
+ }
1591
+
1592
+
1593
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1594
+ """
1595
+ Extract values from a 1-D numpy array for a batch of indices.
1596
+
1597
+ :param arr: the 1-D numpy array.
1598
+ :param timesteps: a tensor of indices into the array to extract.
1599
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1600
+ dimension equal to the length of timesteps.
1601
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1602
+ """
1603
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1604
+ while len(res.shape) < len(broadcast_shape):
1605
+ res = res[..., None]
1606
+ return res.expand(broadcast_shape)
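A minimal sketch of what this helper does; the array and shapes are hypothetical:

import numpy as np
import torch

# Gather one scalar per timestep, then broadcast over a (B, L, D) batch.
arr = np.linspace(0.1, 1.0, 1000)                 # one value per timestep
t = torch.tensor([0, 499, 999])                   # batch of 3 timestep indices
vals = _extract_into_tensor(arr, t, (3, 16, 32))  # -> tensor of shape (3, 16, 32)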
src/improved_diffusion/image_datasets.py ADDED
@@ -0,0 +1,120 @@
1
+ from PIL import Image
2
+ import blobfile as bf
3
+ from mpi4py import MPI
4
+ import numpy as np
5
+ from torch.utils.data import DataLoader, Dataset
6
+
7
+
8
+ def load_data(
9
+ *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, permutation=None
10
+ ):
11
+ """
12
+ For a dataset, create a generator over (images, kwargs) pairs.
13
+
14
+ Each images is an NCHW float tensor, and the kwargs dict contains zero or
15
+ more keys, each of which map to a batched Tensor of their own.
16
+ The kwargs dict can be used for class labels, in which case the key is "y"
17
+ and the values are integer tensors of class labels.
18
+
19
+ :param data_dir: a dataset directory.
20
+ :param batch_size: the batch size of each returned pair.
21
+ :param image_size: the size to which images are resized.
22
+ :param class_cond: if True, include a "y" key in returned dicts for class
23
+ label. If classes are not available and this is true, an
24
+ exception will be raised.
25
+ :param deterministic: if True, yield results in a deterministic order.
26
+ """
27
+ if not data_dir:
28
+ raise ValueError("unspecified data directory")
29
+ all_files = _list_image_files_recursively(data_dir)
30
+ classes = None
31
+ if class_cond:
32
+ # Assume classes are the first part of the filename,
33
+ # before an underscore.
34
+ class_names = [bf.basename(path).split("_")[0] for path in all_files]
35
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
36
+ classes = [sorted_classes[x] for x in class_names]
37
+ dataset = ImageDataset(
38
+ image_size,
39
+ all_files,
40
+ classes=classes,
41
+ shard=MPI.COMM_WORLD.Get_rank(),
42
+ num_shards=MPI.COMM_WORLD.Get_size(),
43
+ permutation=permutation,
44
+ )
45
+ if deterministic:
46
+ loader = DataLoader(
47
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
48
+ )
49
+ else:
50
+ loader = DataLoader(
51
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
52
+ )
53
+ while True:
54
+ yield from loader
55
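A minimal, hypothetical call; the directory path is a placeholder:

# Stream batches of 64x64 class-conditional images forever.
data = load_data(
    data_dir="path/to/images",
    batch_size=8,
    image_size=64,
    class_cond=True,
)
batch, cond = next(data)  # batch: (8, 3, 64, 64) in [-1, 1]; cond["y"]: (8,)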
+
56
+
57
+ def _list_image_files_recursively(data_dir):
58
+ results = []
59
+ for entry in sorted(bf.listdir(data_dir)):
60
+ full_path = bf.join(data_dir, entry)
61
+ ext = entry.split(".")[-1]
62
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
63
+ results.append(full_path)
64
+ elif bf.isdir(full_path):
65
+ results.extend(_list_image_files_recursively(full_path))
66
+ return results
67
+
68
+
69
+ class ImageDataset(Dataset):
70
+ def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1, permutation=None):
71
+ super().__init__()
72
+ self.resolution = resolution
73
+ self.local_images = image_paths[shard:][::num_shards]
74
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
75
+ self.permutation = permutation
76
+
77
+ def __len__(self):
78
+ return len(self.local_images)
79
+
80
+ def __getitem__(self, idx):
81
+ path = self.local_images[idx]
82
+ with bf.BlobFile(path, "rb") as f:
83
+ pil_image = Image.open(f)
84
+ pil_image.load()
85
+
86
+ # We are not on a new enough PIL to support the `reducing_gap`
87
+ # argument, which uses BOX downsampling at powers of two first.
88
+ # Thus, we do it by hand to improve downsample quality.
89
+ while min(*pil_image.size) >= 2 * self.resolution:
90
+ pil_image = pil_image.resize(
91
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
92
+ )
93
+
94
+ scale = self.resolution / min(*pil_image.size)
95
+ pil_image = pil_image.resize(
96
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
97
+ )
98
+
99
+ arr = np.array(pil_image.convert("RGB"))
100
+ crop_y = (arr.shape[0] - self.resolution) // 2
101
+ crop_x = (arr.shape[1] - self.resolution) // 2
102
+ arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution]
103
+ if self.permutation is not None: # pixel value permutation.
104
+ # print('running permutation.')
105
+ # print(arr)
106
+ arr = self.permutation[arr]
107
+ # print(arr)
108
+
109
+ arr = arr.astype(np.float32) / 127.5 - 1
110
+
111
+ # if self.permutation is not None: # pixel location permutation.
112
+ # # print('running permutation.')
113
+ # arr_reshaped = arr.reshape(arr.shape[0] * arr.shape[1], -1)
114
+ # arr_permuted = arr_reshaped[self.permutation,:]
115
+ # arr = arr_permuted.reshape(arr.shape[0], arr.shape[1], -1)
116
+
117
+ out_dict = {}
118
+ if self.local_classes is not None:
119
+ out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
120
+ return np.transpose(arr, [2, 0, 1]), out_dict
src/improved_diffusion/logger.py ADDED
@@ -0,0 +1,498 @@
1
+ """
2
+ Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import os.path as osp
10
+ import json
11
+ import time
12
+ import datetime
13
+ import tempfile
14
+ import warnings
15
+ from collections import defaultdict
16
+ from contextlib import contextmanager
17
+ import wandb
18
+
19
+ DEBUG = 10
20
+ INFO = 20
21
+ WARN = 30
22
+ ERROR = 40
23
+
24
+ DISABLED = 50
25
+
26
+
27
+ class KVWriter(object):
28
+ def writekvs(self, kvs):
29
+ raise NotImplementedError
30
+
31
+
32
+ class SeqWriter(object):
33
+ def writeseq(self, seq):
34
+ raise NotImplementedError
35
+
36
+
37
+ class HumanOutputFormat(KVWriter, SeqWriter):
38
+ def __init__(self, filename_or_file):
39
+ if isinstance(filename_or_file, str):
40
+ self.file = open(filename_or_file, "wt")
41
+ self.own_file = True
42
+ else:
43
+ assert hasattr(filename_or_file, "read"), (
44
+ "expected file or str, got %s" % filename_or_file
45
+ )
46
+ self.file = filename_or_file
47
+ self.own_file = False
48
+
49
+ def writekvs(self, kvs):
50
+ # Create strings for printing
51
+ key2str = {}
52
+ for (key, val) in sorted(kvs.items()):
53
+ if hasattr(val, "__float__"):
54
+ valstr = "%-8.3g" % val
55
+ else:
56
+ valstr = str(val)
57
+ key2str[self._truncate(key)] = self._truncate(valstr)
58
+
59
+ # Find max widths
60
+ if len(key2str) == 0:
61
+ print("WARNING: tried to write empty key-value dict")
62
+ return
63
+ else:
64
+ keywidth = max(map(len, key2str.keys()))
65
+ valwidth = max(map(len, key2str.values()))
66
+
67
+ # Write out the data
68
+ dashes = "-" * (keywidth + valwidth + 7)
69
+ lines = [dashes]
70
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
71
+ lines.append(
72
+ "| %s%s | %s%s |"
73
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
74
+ )
75
+ lines.append(dashes)
76
+ self.file.write("\n".join(lines) + "\n")
77
+
78
+ # Flush the output to the file
79
+ self.file.flush()
80
+
81
+ def _truncate(self, s):
82
+ maxlen = 30
83
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
84
+
85
+ def writeseq(self, seq):
86
+ seq = list(seq)
87
+ for (i, elem) in enumerate(seq):
88
+ self.file.write(elem)
89
+ if i < len(seq) - 1: # add space unless this is the last one
90
+ self.file.write(" ")
91
+ self.file.write("\n")
92
+ self.file.flush()
93
+
94
+ def close(self):
95
+ if self.own_file:
96
+ self.file.close()
97
+
98
+
99
+ class JSONOutputFormat(KVWriter):
100
+ def __init__(self, filename):
101
+ self.file = open(filename, "wt")
102
+
103
+ def writekvs(self, kvs):
104
+ for k, v in sorted(kvs.items()):
105
+ if hasattr(v, "dtype"):
106
+ kvs[k] = float(v)
107
+ self.file.write(json.dumps(kvs) + "\n")
108
+ self.file.flush()
109
+
110
+ def close(self):
111
+ self.file.close()
112
+
113
+
114
+ class CSVOutputFormat(KVWriter):
115
+ def __init__(self, filename):
116
+ self.file = open(filename, "w+t")
117
+ self.keys = []
118
+ self.sep = ","
119
+
120
+ def writekvs(self, kvs):
121
+ # Add our current row to the history
122
+ extra_keys = list(kvs.keys() - self.keys)
123
+ extra_keys.sort()
124
+ if extra_keys:
125
+ self.keys.extend(extra_keys)
126
+ self.file.seek(0)
127
+ lines = self.file.readlines()
128
+ self.file.seek(0)
129
+ for (i, k) in enumerate(self.keys):
130
+ if i > 0:
131
+ self.file.write(",")
132
+ self.file.write(k)
133
+ self.file.write("\n")
134
+ for line in lines[1:]:
135
+ self.file.write(line[:-1])
136
+ self.file.write(self.sep * len(extra_keys))
137
+ self.file.write("\n")
138
+ for (i, k) in enumerate(self.keys):
139
+ if i > 0:
140
+ self.file.write(",")
141
+ v = kvs.get(k)
142
+ if v is not None:
143
+ self.file.write(str(v))
144
+ self.file.write("\n")
145
+ self.file.flush()
146
+
147
+ def close(self):
148
+ self.file.close()
149
+
150
+
151
+ class TensorBoardOutputFormat(KVWriter):
152
+ """
153
+ Dumps key/value pairs into TensorBoard's numeric format.
154
+ """
155
+
156
+ def __init__(self, dir):
157
+ os.makedirs(dir, exist_ok=True)
158
+ self.dir = dir
159
+ self.step = 1
160
+ prefix = "events"
161
+ path = osp.join(osp.abspath(dir), prefix)
162
+ import tensorflow as tf
163
+ from tensorflow.python import pywrap_tensorflow  # requires TensorFlow 1.x APIs
164
+ from tensorflow.core.util import event_pb2
165
+ from tensorflow.python.util import compat
166
+
167
+ self.tf = tf
168
+ self.event_pb2 = event_pb2
169
+ self.pywrap_tensorflow = pywrap_tensorflow
170
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
171
+
172
+ def writekvs(self, kvs):
173
+ def summary_val(k, v):
174
+ kwargs = {"tag": k, "simple_value": float(v)}
175
+ return self.tf.Summary.Value(**kwargs)
176
+
177
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
178
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
179
+ event.step = (
180
+ self.step
181
+ ) # is there any reason why you'd want to specify the step?
182
+ self.writer.WriteEvent(event)
183
+ self.writer.Flush()
184
+ self.step += 1
185
+
186
+ def close(self):
187
+ if self.writer:
188
+ self.writer.Close()
189
+ self.writer = None
190
+
191
+
192
+ def make_output_format(format, ev_dir, log_suffix=""):
193
+ os.makedirs(ev_dir, exist_ok=True)
194
+ if format == "stdout":
195
+ return HumanOutputFormat(sys.stdout)
196
+ elif format == "log":
197
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
198
+ elif format == "json":
199
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
200
+ elif format == "csv":
201
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
202
+ elif format == "tensorboard":
203
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
204
+ else:
205
+ raise ValueError("Unknown format specified: %s" % (format,))
206
+
207
+
208
+ # ================================================================
209
+ # API
210
+ # ================================================================
211
+
212
+
213
+ def logkv(key, val):
214
+ """
215
+ Log a value of some diagnostic
216
+ Call this once for each diagnostic quantity, each iteration
217
+ If called many times, last value will be used.
218
+ """
219
+ get_current().logkv(key, val)
220
+
221
+
222
+ def logkv_mean(key, val):
223
+ """
224
+ The same as logkv(), but if called many times, the values are averaged.
225
+ """
226
+ get_current().logkv_mean(key, val)
227
+
228
+
229
+ def logkvs(d):
230
+ """
231
+ Log a dictionary of key-value pairs
232
+ """
233
+ for (k, v) in d.items():
234
+ logkv(k, v)
235
+
236
+
237
+ def dumpkvs():
238
+ """
239
+ Write all of the diagnostics from the current iteration
240
+ """
241
+ return get_current().dumpkvs()
242
+
243
+
244
+ def getkvs():
245
+ return get_current().name2val
246
+
247
+
248
+ def log(*args, level=INFO):
249
+ """
250
+ Write the sequence of args, space-separated, to the console and output files (if an output file has been configured).
251
+ """
252
+ get_current().log(*args, level=level)
253
+
254
+
255
+ def debug(*args):
256
+ log(*args, level=DEBUG)
257
+
258
+
259
+ def info(*args):
260
+ log(*args, level=INFO)
261
+
262
+
263
+ def warn(*args):
264
+ log(*args, level=WARN)
265
+
266
+
267
+ def error(*args):
268
+ log(*args, level=ERROR)
269
+
270
+
271
+ def set_level(level):
272
+ """
273
+ Set logging threshold on current logger.
274
+ """
275
+ get_current().set_level(level)
276
+
277
+
278
+ def set_comm(comm):
279
+ get_current().set_comm(comm)
280
+
281
+
282
+ def get_dir():
283
+ """
284
+ Get directory that log files are being written to.
285
+ Returns None if there is no output directory (i.e., if configure() was never called).
286
+ """
287
+ return get_current().get_dir()
288
+
289
+
290
+ record_tabular = logkv
291
+ dump_tabular = dumpkvs
292
+
293
+
294
+ @contextmanager
295
+ def profile_kv(scopename):
296
+ logkey = "wait_" + scopename
297
+ tstart = time.time()
298
+ try:
299
+ yield
300
+ finally:
301
+ get_current().name2val[logkey] += time.time() - tstart
302
+
303
+
304
+ def profile(n):
305
+ """
306
+ Usage:
307
+ @profile("my_func")
308
+ def my_func(): code
309
+ """
310
+
311
+ def decorator_with_name(func):
312
+ def func_wrapper(*args, **kwargs):
313
+ with profile_kv(n):
314
+ return func(*args, **kwargs)
315
+
316
+ return func_wrapper
317
+
318
+ return decorator_with_name
319
+
320
+
321
+ # ================================================================
322
+ # Backend
323
+ # ================================================================
324
+
325
+
326
+ def get_current():
327
+ if Logger.CURRENT is None:
328
+ _configure_default_logger()
329
+
330
+ return Logger.CURRENT
331
+
332
+
333
+ class Logger(object):
334
+ DEFAULT = None # A logger with no output files. (See right below class definition)
335
+ # So that you can still log to the terminal without setting up any output files
336
+ CURRENT = None # Current logger being used by the free functions above
337
+
338
+ def __init__(self, dir, output_formats, comm=None):
339
+ self.name2val = defaultdict(float) # values this iteration
340
+ self.name2cnt = defaultdict(int)
341
+ self.level = INFO
342
+ self.dir = dir
343
+ self.output_formats = output_formats
344
+ self.comm = comm
345
+
346
+ # Logging API, forwarded
347
+ # ----------------------------------------
348
+ def logkv(self, key, val):
349
+ self.name2val[key] = val
350
+
351
+ def logkv_mean(self, key, val):
352
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
353
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
354
+ self.name2cnt[key] = cnt + 1
355
+
356
+ def dumpkvs(self, prefix=None):
357
+ if self.comm is None:
358
+ d = self.name2val
359
+ else:
360
+ d = mpi_weighted_mean(
361
+ self.comm,
362
+ {
363
+ name: (val, self.name2cnt.get(name, 1))
364
+ for (name, val) in self.name2val.items()
365
+ },
366
+ )
367
+ if self.comm.rank != 0:
368
+ d["dummy"] = 1 # so we don't get a warning about empty dict
369
+ # LISA
370
+ wandb.log({**d})
371
+ out = d.copy() # Return the dict for unit testing purposes
372
+ for fmt in self.output_formats:
373
+ if isinstance(fmt, KVWriter):
374
+ fmt.writekvs(d)
375
+ self.name2val.clear()
376
+ self.name2cnt.clear()
377
+ return out
378
+
379
+ def log(self, *args, level=INFO):
380
+ if self.level <= level:
381
+ self._do_log(args)
382
+
383
+ # Configuration
384
+ # ----------------------------------------
385
+ def set_level(self, level):
386
+ self.level = level
387
+
388
+ def set_comm(self, comm):
389
+ self.comm = comm
390
+
391
+ def get_dir(self):
392
+ return self.dir
393
+
394
+ def close(self):
395
+ for fmt in self.output_formats:
396
+ fmt.close()
397
+
398
+ # Misc
399
+ # ----------------------------------------
400
+ def _do_log(self, args):
401
+ for fmt in self.output_formats:
402
+ if isinstance(fmt, SeqWriter):
403
+ fmt.writeseq(map(str, args))
404
+
405
+
406
+ def get_rank_without_mpi_import():
407
+ # check environment variables here instead of importing mpi4py
408
+ # to avoid calling MPI_Init() when this module is imported
409
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
410
+ if varname in os.environ:
411
+ return int(os.environ[varname])
412
+ return 0
413
+
414
+
415
+ def mpi_weighted_mean(comm, local_name2valcount):
416
+ """
417
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
418
+ Perform a weighted average over dicts that are each on a different node
419
+ Input: local_name2valcount: dict mapping key -> (value, count)
420
+ Returns: key -> mean
421
+ """
422
+ all_name2valcount = comm.gather(local_name2valcount)
423
+ if comm.rank == 0:
424
+ name2sum = defaultdict(float)
425
+ name2count = defaultdict(float)
426
+ for n2vc in all_name2valcount:
427
+ for (name, (val, count)) in n2vc.items():
428
+ try:
429
+ val = float(val)
430
+ except ValueError:
431
+ if comm.rank == 0:
432
+ warnings.warn(
433
+ "WARNING: tried to compute mean on non-float {}={}".format(
434
+ name, val
435
+ )
436
+ )
437
+ else:
438
+ name2sum[name] += val * count
439
+ name2count[name] += count
440
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
441
+ else:
442
+ return {}
443
+
444
+
445
+ def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
446
+ """
447
+ If comm is provided, average all numerical stats across that comm
448
+ """
449
+ if dir is None:
450
+ dir = os.getenv("OPENAI_LOGDIR")
451
+ if dir is None:
452
+ dir = osp.join(
453
+ tempfile.gettempdir(),
454
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
455
+ )
456
+ assert isinstance(dir, str)
457
+ dir = os.path.expanduser(dir)
458
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
459
+
460
+ rank = get_rank_without_mpi_import()
461
+ if rank > 0:
462
+ log_suffix = log_suffix + "-rank%03i" % rank
463
+
464
+ if format_strs is None:
465
+ if rank == 0:
466
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
467
+ else:
468
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
469
+ format_strs = filter(None, format_strs)
470
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
471
+
472
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
473
+ if output_formats:
474
+ log("Logging to %s" % dir)
475
+
476
+
477
+ def _configure_default_logger():
478
+ configure()
479
+ Logger.DEFAULT = Logger.CURRENT
480
+
481
+
482
+ def reset():
483
+ if Logger.CURRENT is not Logger.DEFAULT:
484
+ Logger.CURRENT.close()
485
+ Logger.CURRENT = Logger.DEFAULT
486
+ log("Reset logger")
487
+
488
+
489
+ @contextmanager
490
+ def scoped_configure(dir=None, format_strs=None, comm=None):
491
+ prevlogger = Logger.CURRENT
492
+ configure(dir=dir, format_strs=format_strs, comm=comm)
493
+ try:
494
+ yield
495
+ finally:
496
+ Logger.CURRENT.close()
497
+ Logger.CURRENT = prevlogger
498
+
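A minimal sketch of driving the free-function API above (the import path, project name, and keys are assumptions, not repo code); note that dumpkvs() forwards every dump to wandb, so wandb.init() must run first:

# Minimal usage sketch (illustrative; the module path is an assumption).
import wandb
from src.improved_diffusion import logger

wandb.init(project="demo", mode="disabled")  # "disabled" avoids any network I/O
logger.configure(dir="/tmp/logger_demo", format_strs=["stdout", "csv"])

for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 1.0 / (step + 1))  # averaged within one iteration
    logger.dumpkvs()  # writes one table/CSV row and clears the accumulators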
src/improved_diffusion/losses.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Helpers for various likelihood-based losses. These are ported from the original
3
+ Ho et al. diffusion models codebase:
4
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ import torch as th
10
+
11
+
12
+ def normal_kl(mean1, logvar1, mean2, logvar2):
13
+ """
14
+ Compute the KL divergence between two Gaussians.
15
+
16
+ Shapes are automatically broadcasted, so batches can be compared to
17
+ scalars, among other use cases.
18
+ """
19
+ tensor = None
20
+ for obj in (mean1, logvar1, mean2, logvar2):
21
+ if isinstance(obj, th.Tensor):
22
+ tensor = obj
23
+ break
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ # print(logvar2.shape)
34
+ # temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2))
35
+ # print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, mse={((mean1 - mean2) ** 2).mean().item()}')
36
+
37
+ return 0.5 * (
38
+ -1.0
39
+ + logvar2
40
+ - logvar1
41
+ + th.exp(logvar1 - logvar2)
42
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
43
+ )
44
+
45
+
46
+ def approx_standard_normal_cdf(x):
47
+ """
48
+ A fast approximation of the cumulative distribution function of the
49
+ standard normal.
50
+ """
51
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
52
+
53
+
54
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
55
+ """
56
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
57
+ given image.
58
+
59
+ :param x: the target images. It is assumed that these were uint8 values,
60
+ rescaled to the range [-1, 1].
61
+ :param means: the Gaussian mean Tensor.
62
+ :param log_scales: the Gaussian log stddev Tensor.
63
+ :return: a tensor like x of log probabilities (in nats).
64
+ """
65
+ assert x.shape == means.shape == log_scales.shape
66
+ centered_x = x - means
67
+ inv_stdv = th.exp(-log_scales)
68
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
69
+ cdf_plus = approx_standard_normal_cdf(plus_in)
70
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
71
+ cdf_min = approx_standard_normal_cdf(min_in)
72
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
73
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
74
+ cdf_delta = cdf_plus - cdf_min
75
+ log_probs = th.where(
76
+ x < -0.999,
77
+ log_cdf_plus,
78
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
79
+ )
80
+ assert log_probs.shape == x.shape
81
+ return log_probs
82
+
83
+ def gaussian_density(x, *, means, log_scales):
84
+ from torch.distributions import Normal
85
+ normal_dist = Normal(means, log_scales.exp())
86
+ logp = normal_dist.log_prob(x)
87
+ return logp
88
+
89
+
90
+ def discretized_text_log_likelihood(x, *, means, log_scales):
91
+ """
92
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
93
+ given text tensor (docstring adapted from the image version above).
94
+
95
+ :param x: the target values. They are assumed to be uint8 values,
96
+ rescaled to the range [-1, 1].
97
+ :param means: the Gaussian mean Tensor.
98
+ :param log_scales: the Gaussian log stddev Tensor.
99
+ :return: a tensor like x of log probabilities (in nats).
100
+ """
101
+ # print(x.shape, means.shape)  # debug print, disabled
102
+ # assert x.shape == means.shape == log_scales.shape
103
+ # print(x, means)  # debug print, disabled
104
+ centered_x = x - means
105
+ inv_stdv = th.exp(-log_scales)
106
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
107
+ cdf_plus = approx_standard_normal_cdf(plus_in)
108
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
109
+ cdf_min = approx_standard_normal_cdf(min_in)
110
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
111
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
112
+ cdf_delta = cdf_plus - cdf_min
113
+ log_probs = th.where(
114
+ x < -0.999,
115
+ log_cdf_plus,
116
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
117
+ )
118
+ assert log_probs.shape == x.shape
119
+ return log_probs
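As a quick sanity check (a sketch, not repo code; the import path is an assumption), normal_kl above can be compared against torch.distributions.kl_divergence, since both compute the KL between diagonal Gaussians parameterized by mean and log-variance:

# Sanity-check sketch for normal_kl.
import torch as th
from torch.distributions import Normal, kl_divergence
from src.improved_diffusion.losses import normal_kl

mean1, logvar1 = th.randn(4), th.randn(4)
mean2, logvar2 = th.randn(4), th.randn(4)

kl_manual = normal_kl(mean1, logvar1, mean2, logvar2)
# Normal takes a stddev, so exponentiate half of each log-variance.
kl_ref = kl_divergence(Normal(mean1, (0.5 * logvar1).exp()),
                       Normal(mean2, (0.5 * logvar2).exp()))
assert th.allclose(kl_manual, kl_ref, atol=1e-4)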
src/improved_diffusion/nn.py ADDED
@@ -0,0 +1,170 @@
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+
11
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12
+ class SiLU(nn.Module):
13
+ def forward(self, x):
14
+ return x * th.sigmoid(x)
15
+
16
+
17
+ class GroupNorm32(nn.GroupNorm):
18
+ def forward(self, x):
19
+ return super().forward(x.float()).type(x.dtype)
20
+
21
+
22
+ def conv_nd(dims, *args, **kwargs):
23
+ """
24
+ Create a 1D, 2D, or 3D convolution module.
25
+ """
26
+ if dims == 1:
27
+ return nn.Conv1d(*args, **kwargs)
28
+ elif dims == 2:
29
+ return nn.Conv2d(*args, **kwargs)
30
+ elif dims == 3:
31
+ return nn.Conv3d(*args, **kwargs)
32
+ raise ValueError(f"unsupported dimensions: {dims}")
33
+
34
+
35
+ def linear(*args, **kwargs):
36
+ """
37
+ Create a linear module.
38
+ """
39
+ return nn.Linear(*args, **kwargs)
40
+
41
+
42
+ def avg_pool_nd(dims, *args, **kwargs):
43
+ """
44
+ Create a 1D, 2D, or 3D average pooling module.
45
+ """
46
+ if dims == 1:
47
+ return nn.AvgPool1d(*args, **kwargs)
48
+ elif dims == 2:
49
+ return nn.AvgPool2d(*args, **kwargs)
50
+ elif dims == 3:
51
+ return nn.AvgPool3d(*args, **kwargs)
52
+ raise ValueError(f"unsupported dimensions: {dims}")
53
+
54
+
55
+ def update_ema(target_params, source_params, rate=0.99):
56
+ """
57
+ Update target parameters to be closer to those of source parameters using
58
+ an exponential moving average.
59
+
60
+ :param target_params: the target parameter sequence.
61
+ :param source_params: the source parameter sequence.
62
+ :param rate: the EMA rate (closer to 1 means slower).
63
+ """
64
+ for targ, src in zip(target_params, source_params):
65
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66
+
67
+
68
+ def zero_module(module):
69
+ """
70
+ Zero out the parameters of a module and return it.
71
+ """
72
+ for p in module.parameters():
73
+ p.detach().zero_()
74
+ return module
75
+
76
+
77
+ def scale_module(module, scale):
78
+ """
79
+ Scale the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().mul_(scale)
83
+ return module
84
+
85
+
86
+ def mean_flat(tensor):
87
+ """
88
+ Take the mean over all non-batch dimensions.
89
+ """
90
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
+
92
+
93
+ def normalization(channels):
94
+ """
95
+ Make a standard normalization layer.
96
+
97
+ :param channels: number of input channels.
98
+ :return: an nn.Module for normalization.
99
+ """
100
+ return GroupNorm32(32, channels)
101
+
102
+
103
+ def timestep_embedding(timesteps, dim, max_period=10000):
104
+ """
105
+ Create sinusoidal timestep embeddings.
106
+
107
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
108
+ These may be fractional.
109
+ :param dim: the dimension of the output.
110
+ :param max_period: controls the minimum frequency of the embeddings.
111
+ :return: an [N x dim] Tensor of positional embeddings.
112
+ """
113
+ half = dim // 2
114
+ freqs = th.exp(
115
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116
+ ).to(device=timesteps.device)
117
+ args = timesteps[:, None].float() * freqs[None]
118
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121
+ return embedding
122
+
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+
129
+ :param func: the function to evaluate.
130
+ :param inputs: the argument sequence to pass to `func`.
131
+ :param params: a sequence of parameters `func` depends on but does not
132
+ explicitly take as arguments.
133
+ :param flag: if False, disable gradient checkpointing.
134
+ """
135
+ if flag:
136
+ args = tuple(inputs) + tuple(params)
137
+ return CheckpointFunction.apply(func, len(inputs), *args)
138
+ else:
139
+ return func(*inputs)
140
+
141
+
142
+ class CheckpointFunction(th.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, run_function, length, *args):
145
+ ctx.run_function = run_function
146
+ ctx.input_tensors = list(args[:length])
147
+ ctx.input_params = list(args[length:])
148
+ with th.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with th.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = th.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
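Two helpers above are worth a shape check. A small sketch (the module path is an assumption): timestep_embedding maps N scalar timesteps to an [N x dim] sinusoidal table, and update_ema blends parameter sequences in place:

# Illustrative sketch of the shape/update contracts.
import torch as th
from src.improved_diffusion.nn import timestep_embedding, update_ema

t = th.arange(8)                      # one timestep index per batch element
emb = timestep_embedding(t, dim=128)  # cos/sin features at geometric frequencies
assert emb.shape == (8, 128)

targ, src = [th.zeros(3)], [th.ones(3)]
update_ema(targ, src, rate=0.9)       # targ <- 0.9 * targ + 0.1 * src, in place
assert th.allclose(targ[0], th.full((3,), 0.1))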
src/improved_diffusion/resample.py ADDED
@@ -0,0 +1,154 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
+ def create_named_schedule_sampler(name, diffusion):
9
+ """
10
+ Create a ScheduleSampler from a library of pre-defined samplers.
11
+
12
+ :param name: the name of the sampler.
13
+ :param diffusion: the diffusion object to sample for.
14
+ """
15
+ if name == "uniform":
16
+ return UniformSampler(diffusion)
17
+ elif name == "loss-second-moment":
18
+ return LossSecondMomentResampler(diffusion)
19
+ else:
20
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
21
+
22
+
23
+ class ScheduleSampler(ABC):
24
+ """
25
+ A distribution over timesteps in the diffusion process, intended to reduce
26
+ variance of the objective.
27
+
28
+ By default, samplers perform unbiased importance sampling, in which the
29
+ objective's mean is unchanged.
30
+ However, subclasses may override sample() to change how the resampled
31
+ terms are reweighted, allowing for actual changes in the objective.
32
+ """
33
+
34
+ @abstractmethod
35
+ def weights(self):
36
+ """
37
+ Get a numpy array of weights, one per diffusion step.
38
+
39
+ The weights needn't be normalized, but must be positive.
40
+ """
41
+
42
+ def sample(self, batch_size, device):
43
+ """
44
+ Importance-sample timesteps for a batch.
45
+
46
+ :param batch_size: the number of timesteps.
47
+ :param device: the torch device to save to.
48
+ :return: a tuple (timesteps, weights):
49
+ - timesteps: a tensor of timestep indices.
50
+ - weights: a tensor of weights to scale the resulting losses.
51
+ """
52
+ w = self.weights()
53
+ p = w / np.sum(w)
54
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55
+ indices = th.from_numpy(indices_np).long().to(device)
56
+ weights_np = 1 / (len(p) * p[indices_np])
57
+ weights = th.from_numpy(weights_np).float().to(device)
58
+ return indices, weights
59
+
60
+
61
+ class UniformSampler(ScheduleSampler):
62
+ def __init__(self, diffusion):
63
+ self.diffusion = diffusion
64
+ self._weights = np.ones([diffusion.num_timesteps])
65
+
66
+ def weights(self):
67
+ return self._weights
68
+
69
+
70
+ class LossAwareSampler(ScheduleSampler):
71
+ def update_with_local_losses(self, local_ts, local_losses):
72
+ """
73
+ Update the reweighting using losses from a model.
74
+
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+
80
+ :param local_ts: an integer Tensor of timesteps.
81
+ :param local_losses: a 1D Tensor of losses.
82
+ """
83
+ batch_sizes = [
84
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
85
+ for _ in range(dist.get_world_size())
86
+ ]
87
+ dist.all_gather(
88
+ batch_sizes,
89
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90
+ )
91
+
92
+ # Pad all_gather batches to be the maximum batch size.
93
+ batch_sizes = [x.item() for x in batch_sizes]
94
+ max_bs = max(batch_sizes)
95
+
96
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
97
+ loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
98
+ dist.all_gather(timestep_batches, local_ts)
99
+ dist.all_gather(loss_batches, local_losses)
100
+ timesteps = [
101
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102
+ ]
103
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104
+ self.update_with_all_losses(timesteps, losses)
105
+
106
+ @abstractmethod
107
+ def update_with_all_losses(self, ts, losses):
108
+ """
109
+ Update the reweighting using losses from a model.
110
+
111
+ Sub-classes should override this method to update the reweighting
112
+ using losses from the model.
113
+
114
+ This method directly updates the reweighting without synchronizing
115
+ between workers. It is called by update_with_local_losses from all
116
+ ranks with identical arguments. Thus, it should have deterministic
117
+ behavior to maintain state across workers.
118
+
119
+ :param ts: a list of int timesteps.
120
+ :param losses: a list of float losses, one per timestep.
121
+ """
122
+
123
+
124
+ class LossSecondMomentResampler(LossAwareSampler):
125
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126
+ self.diffusion = diffusion
127
+ self.history_per_term = history_per_term
128
+ self.uniform_prob = uniform_prob
129
+ self._loss_history = np.zeros(
130
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
131
+ )
132
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
133
+
134
+ def weights(self):
135
+ if not self._warmed_up():
136
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138
+ weights /= np.sum(weights)
139
+ weights *= 1 - self.uniform_prob
140
+ weights += self.uniform_prob / len(weights)
141
+ return weights
142
+
143
+ def update_with_all_losses(self, ts, losses):
144
+ for t, loss in zip(ts, losses):
145
+ if self._loss_counts[t] == self.history_per_term:
146
+ # Shift out the oldest loss term.
147
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
148
+ self._loss_history[t, -1] = loss
149
+ else:
150
+ self._loss_history[t, self._loss_counts[t]] = loss
151
+ self._loss_counts[t] += 1
152
+
153
+ def _warmed_up(self):
154
+ return (self._loss_counts == self.history_per_term).all()
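How a schedule sampler is driven during training, as a hedged sketch (the one-attribute diffusion stub and import path are assumptions): sample() returns timesteps plus importance weights, and under the uniform schedule every weight is exactly 1, leaving the loss estimate unchanged:

# Illustrative sketch, not repo code.
import torch as th
from src.improved_diffusion.resample import UniformSampler

class DiffusionStub:
    num_timesteps = 1000  # the only attribute the samplers read

sampler = UniformSampler(DiffusionStub())
t, weights = sampler.sample(batch_size=16, device=th.device("cpu"))
assert t.shape == (16,) and th.allclose(weights, th.ones(16))
# In training, the per-sample losses would be multiplied by `weights`.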
src/improved_diffusion/respace.py ADDED
@@ -0,0 +1,131 @@
1
+ import numpy as np
2
+ import torch as th
3
+
4
+ from .gaussian_diffusion import GaussianDiffusion
5
+
6
+
7
+ def space_timesteps(num_timesteps, section_counts):
8
+ """
9
+ Create a list of timesteps to use from an original diffusion process,
10
+ given the number of timesteps we want to take from equally-sized portions
11
+ of the original process.
12
+
13
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
14
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
15
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
16
+
17
+ If the stride is a string starting with "ddim", then the fixed striding
18
+ from the DDIM paper is used, and only one section is allowed.
19
+
20
+ :param num_timesteps: the number of diffusion steps in the original
21
+ process to divide up.
22
+ :param section_counts: either a list of numbers, or a string containing
23
+ comma-separated numbers, indicating the step count
24
+ per section. As a special case, use "ddimN" where N
25
+ is a number of steps to use the striding from the
26
+ DDIM paper.
27
+ :return: a set of diffusion steps from the original process to use.
28
+ """
29
+ # if isinstance(section_counts, str):
30
+ # if section_counts.startswith("ddim"):
31
+ # desired_count = int(section_counts[len("ddim") :])
32
+ # for i in range(1, num_timesteps):
33
+ # if len(range(0, num_timesteps, i)) == desired_count:
34
+ # return set(range(0, num_timesteps, i))
35
+ # raise ValueError(
36
+ # f"cannot create exactly {num_timesteps} steps with an integer stride"
37
+ # )
38
+ # section_counts = [int(x) for x in section_counts.split(",")]
39
+ size_per = num_timesteps // len(section_counts)
40
+ extra = num_timesteps % len(section_counts)
41
+ start_idx = 0
42
+ all_steps = []
43
+ for i, section_count in enumerate(section_counts):
44
+ size = size_per + (1 if i < extra else 0)
45
+ if size < section_count:
46
+ raise ValueError(
47
+ f"cannot divide section of {size} steps into {section_count}"
48
+ )
49
+ if section_count <= 1:
50
+ frac_stride = 1
51
+ else:
52
+ frac_stride = (size - 1) / (section_count - 1)
53
+ cur_idx = 0.0
54
+ taken_steps = []
55
+ for _ in range(section_count):
56
+ taken_steps.append(start_idx + round(cur_idx))
57
+ cur_idx += frac_stride
58
+ all_steps += taken_steps
59
+ start_idx += size
60
+ return set(all_steps)
61
+
62
+
63
+ class SpacedDiffusion(GaussianDiffusion):
64
+ """
65
+ A diffusion process which can skip steps in a base diffusion process.
66
+
67
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
68
+ original diffusion process to retain.
69
+ :param kwargs: the kwargs to create the base diffusion process.
70
+ """
71
+
72
+ def __init__(self, use_timesteps, **kwargs):
73
+ self.use_timesteps = set(use_timesteps)
74
+ self.timestep_map = []
75
+ self.original_num_steps = len(kwargs["betas"])
76
+
77
+ # print(kwargs.keys())
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ # print('called p_mean_var')
93
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
94
+
95
+ def training_losses(
96
+ self, model, *args, **kwargs
97
+ ): # pylint: disable=signature-differs
98
+ # print('called training_losses')
99
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
100
+
101
+ def _wrap_model(self, model):
102
+ if isinstance(model, _WrappedModel):
103
+ return model
104
+ return _WrappedModel(
105
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
106
+ )
107
+
108
+ def _scale_timesteps(self, t):
109
+ # Scaling is done by the wrapped model.
110
+ return t
111
+
112
+
113
+ class _WrappedModel:
114
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
115
+ self.model = model
116
+ self.timestep_map = timestep_map
117
+ self.rescale_timesteps = rescale_timesteps
118
+ self.original_num_steps = original_num_steps
119
+
120
+ def __call__(self, x, ts, *args, **kwargs):
121
+ # print(ts)
122
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
123
+ new_ts = map_tensor[ts]
124
+ # print(new_ts)
125
+ if self.rescale_timesteps:
126
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
127
+ # temp = self.model(x, new_ts, **kwargs)
128
+ # print(temp.shape)
129
+ # return temp
130
+ # print(new_ts)
131
+ return self.model(x, new_ts, *args, **kwargs)
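space_timesteps above is easiest to see by example. A quick check (a sketch; the import path is an assumption): with a stride of (size - 1) / (count - 1) per section, both endpoints of each section are kept and the rest are spaced evenly (Python's banker's rounding sends 4.5 to 4):

# Illustrative checks of space_timesteps.
from src.improved_diffusion.respace import space_timesteps

assert space_timesteps(10, [5]) == {0, 2, 4, 7, 9}
# Two sections of 5 original steps each, taking 2 and 3 steps respectively:
assert space_timesteps(10, [2, 3]) == {0, 4, 5, 7, 9}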
src/improved_diffusion/rounding.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+ # bert results
3
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
4
+ import sys, yaml, os
5
+ # print( os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
6
+ # sys.path.insert(0, 'diffusion_lm/transformers/examples/pytorch/language-modeling')
7
+ # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
8
+ # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
9
+
10
+ def load_models(modality, mode, model_name_or_path, emb_dim, file, extra_args=None):
11
+
12
+ if mode in ['random', 'random1', 'random_up_proj', 'glove']:
13
+ if modality == 'synth':
14
+ pass# print(file, 'deciding what to load::: ')
15
+ # if 'synth128' in file:
16
+ # config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml'
17
+ # else:
18
+ # config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k32_trainc20000.yaml'
19
+ # import sys, os
20
+ # sys.path.insert(0, 'diffusion_lm/synthetic_data/rnns-stacks')
21
+ # from dataset import Dataset as SynthDataset
22
+ # args_synth = yaml.load(open(config))
23
+ # dataset = SynthDataset(args_synth)
24
+ # model = torch.nn.Embedding(len(dataset.vocab), emb_dim)
25
+ # print('initializing the random embeddings', model)
26
+ # # print(os.path.split(file.split('.')[0])[-1])
27
+ # # path_save = '{}/random_emb.torch'.format(file)
28
+ # path_save = '{}/random_emb.torch'.format(file)
29
+ # model.load_state_dict(torch.load(path_save))
30
+ # print(dataset.vocab)
31
+ # tokenizer = {v: k for k, v in dataset.vocab.items()}
32
+ else:
33
+ import json
34
+ if modality == 'book' or (extra_args is not None and extra_args.use_bert_tokenizer == 'yes'):
35
+ pass# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
36
+ # if 'e2e' in file and modality == 'book':
37
+ # emb_dim = 1
38
+ else:
39
+ path_save_tokenizer = '{}/vocab.json'.format(file)
40
+ # path_save_tokenizer = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json'  # machine-specific override disabled; use the path derived from `file` above
41
+ print(f'loading from {path_save_tokenizer}')
42
+ with open(path_save_tokenizer, 'r') as f:
43
+ vocab = json.load(f)
44
+ print(len(vocab))
45
+ tokenizer = {v: k for k, v in vocab.items()}
46
+ model = torch.nn.Embedding(len(tokenizer), emb_dim)
47
+ path_save = '{}/random_emb.torch'.format(file)
48
+ # path_save = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch'  # machine-specific override disabled; use the path derived from `file` above
49
+ model.load_state_dict(torch.load(path_save))
50
+
51
+ return model, tokenizer  # NOTE: both are undefined when the 'synth' or 'book' branches above are taken (their bodies are commented out)
52
+
53
+
54
+ def load_tokenizer(modality, mode, model_name_or_path):
55
+ if mode in ['random', 'random_up_proj', 'glove']:
56
+ if modality == 'synth':
57
+ print(model_name_or_path, 'deciding what to load::: ')
58
+ if 'synth128' in model_name_or_path:
59
+ config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml'
60
+ else:
61
+ config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k32_trainc20000.yaml'
62
+
63
+ import sys, os
64
+ sys.path.insert(0, 'diffusion_lm/synthetic_data/rnns-stacks')
65
+ from dataset import Dataset as SynthDataset
66
+ args_synth = yaml.load(open(config), Loader=yaml.FullLoader)  # Loader is required by PyYAML >= 6
67
+ dataset = SynthDataset(args_synth)
68
+ tokenizer = {v: k for k, v in dataset.vocab.items()}
69
+ elif modality =='book':
70
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
71
+ else:
72
+ import json
73
+ path_save_tokenizer = '{}/vocab.json'.format(model_name_or_path)
74
+ with open(path_save_tokenizer, 'r') as f:
75
+ vocab = json.load(f)
76
+ tokenizer = {v: k for k, v in vocab.items()}
77
+
78
+ return tokenizer
79
+
80
+ def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0):
81
+ decoded_out_lst = []
82
+ if mode in ['random', 'random_up_proj', 'glove']:
83
+ down_proj_emb = model.weight # input_embs
84
+ down_proj_emb2 = None
85
+
86
+
87
+ def get_knn(down_proj_emb, text_emb, dist='cos'):
88
+
89
+ if dist == 'cos':
90
+ adjacency = down_proj_emb @ text_emb.transpose(1, 0).to(down_proj_emb.device)
91
+ elif dist == 'l2':
92
+ adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
93
+ down_proj_emb.size(0), -1, -1)
94
+ adjacency = -torch.norm(adjacency, dim=-1)
95
+ topk_out = torch.topk(adjacency, k=6, dim=0)
96
+ return topk_out.values, topk_out.indices
97
+
98
+ dist = 'l2'
99
+ # print(npzfile['arr_0'].shape)
100
+ for text_emb in text_emb_lst:
101
+ import torch
102
+ text_emb = torch.tensor(text_emb)
103
+ # print(text_emb.shape)
104
+ if len(text_emb.shape) > 2:
105
+ text_emb = text_emb.view(-1, text_emb.size(-1))
106
+ else:
107
+ text_emb = text_emb
108
+ val, indices = get_knn((down_proj_emb2 if dist == 'cos' else down_proj_emb),
109
+ text_emb.to(down_proj_emb.device), dist=dist)
110
+ # generated_lst.append(tuple(indices[0].tolist()))
111
+
112
+ # print(indices[0].tolist())
113
+ # for i in range(64):
114
+ # print([tokenizer[x.item()] for x in indices[:,i]])
115
+ decoded_out = " ".join([tokenizer[i] for i in indices[0].tolist()])
116
+ decoded_out_lst.append(decoded_out)
117
+
118
+ return decoded_out_lst
119
+
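The heart of rounding_func is mapping each continuous embedding back to its nearest vocabulary token. A self-contained sketch of that step (toy table and vocab; none of these names come from the repo):

# Self-contained sketch of embedding-to-token rounding (illustrative only).
import torch

torch.manual_seed(0)
vocab = {0: "START", 1: "END", 2: "cat", 3: "dog"}
emb_table = torch.randn(len(vocab), 16)        # toy embedding matrix

noisy = emb_table[2] + 0.01 * torch.randn(16)  # a slightly perturbed "cat"
# L2 nearest neighbour over the vocabulary, as in get_knn(dist='l2') above.
dists = torch.norm(emb_table - noisy.unsqueeze(0), dim=-1)
assert vocab[int(dists.argmin())] == "cat"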
src/improved_diffusion/script_util.py ADDED
@@ -0,0 +1,201 @@
1
+ import argparse
2
+
3
+ from . import gaussian_diffusion as gd
4
+ from .respace import SpacedDiffusion, space_timesteps
5
+
6
+ # from .unet import SuperResModel
7
+
8
+ NUM_CLASSES = 1000
9
+
10
+
11
+ def model_and_diffusion_defaults():
12
+ """
13
+ Defaults for image training.
14
+ """
15
+ return dict(
16
+ image_size=64,
17
+ num_channels=128,
18
+ num_res_blocks=2,
19
+ num_heads=4,
20
+ num_heads_upsample=-1,
21
+ attention_resolutions="16,8",
22
+ dropout=0.0,
23
+ learn_sigma=False,
24
+ class_cond=False,
25
+ diffusion_steps=1000,
26
+ noise_schedule="linear",
27
+ timestep_respacing="",
28
+ use_kl=False,
29
+ predict_xstart=False,
30
+ rescale_timesteps=True,
31
+ rescale_learned_sigmas=True,
32
+ use_checkpoint=False,
33
+ use_scale_shift_norm=True,
34
+ model_arch="trans-unet",
35
+ in_channel=8,
36
+ out_channel=8,
37
+ training_mode="emb",
38
+ vocab_size=66,
39
+ config_name="QizhiPei/biot5-base-text2mol",
40
+ experiment_mode="lm",
41
+ logits_mode=1,
42
+ )
43
+
44
+
45
+ # def sr_model_and_diffusion_defaults():
46
+ # res = model_and_diffusion_defaults()
47
+ # res["large_size"] = 256
48
+ # res["small_size"] = 64
49
+ # arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
50
+ # for k in res.copy().keys():
51
+ # if k not in arg_names:
52
+ # del res[k]
53
+ # return res
54
+
55
+
56
+ # def sr_create_model_and_diffusion(
57
+ # large_size,
58
+ # small_size,
59
+ # class_cond,
60
+ # learn_sigma,
61
+ # num_channels,
62
+ # num_res_blocks,
63
+ # num_heads,
64
+ # num_heads_upsample,
65
+ # attention_resolutions,
66
+ # dropout,
67
+ # diffusion_steps,
68
+ # noise_schedule,
69
+ # timestep_respacing,
70
+ # use_kl,
71
+ # predict_xstart,
72
+ # rescale_timesteps,
73
+ # rescale_learned_sigmas,
74
+ # use_checkpoint,
75
+ # use_scale_shift_norm,
76
+ # ):
77
+ # model = sr_create_model(
78
+ # large_size,
79
+ # small_size,
80
+ # num_channels,
81
+ # num_res_blocks,
82
+ # learn_sigma=learn_sigma,
83
+ # class_cond=class_cond,
84
+ # use_checkpoint=use_checkpoint,
85
+ # attention_resolutions=attention_resolutions,
86
+ # num_heads=num_heads,
87
+ # num_heads_upsample=num_heads_upsample,
88
+ # use_scale_shift_norm=use_scale_shift_norm,
89
+ # dropout=dropout,
90
+ # )
91
+ # diffusion = create_gaussian_diffusion(
92
+ # steps=diffusion_steps,
93
+ # learn_sigma=learn_sigma,
94
+ # noise_schedule=noise_schedule,
95
+ # use_kl=use_kl,
96
+ # predict_xstart=predict_xstart,
97
+ # rescale_timesteps=rescale_timesteps,
98
+ # rescale_learned_sigmas=rescale_learned_sigmas,
99
+ # timestep_respacing=timestep_respacing,
100
+ # )
101
+ # return model, diffusion
102
+
103
+
104
+ # def sr_create_model(
105
+ # large_size,
106
+ # small_size,
107
+ # num_channels,
108
+ # num_res_blocks,
109
+ # learn_sigma,
110
+ # class_cond,
111
+ # use_checkpoint,
112
+ # attention_resolutions,
113
+ # num_heads,
114
+ # num_heads_upsample,
115
+ # use_scale_shift_norm,
116
+ # dropout,
117
+ # ):
118
+ # _ = small_size # hack to prevent unused variable
119
+
120
+ # if large_size == 256:
121
+ # channel_mult = (1, 1, 2, 2, 4, 4)
122
+ # elif large_size == 64:
123
+ # channel_mult = (1, 2, 3, 4)
124
+ # else:
125
+ # raise ValueError(f"unsupported large size: {large_size}")
126
+
127
+ # attention_ds = []
128
+ # for res in attention_resolutions.split(","):
129
+ # attention_ds.append(large_size // int(res))
130
+
131
+ # return SuperResModel(
132
+ # in_channels=3,
133
+ # model_channels=num_channels,
134
+ # out_channels=(3 if not learn_sigma else 6),
135
+ # num_res_blocks=num_res_blocks,
136
+ # attention_resolutions=tuple(attention_ds),
137
+ # dropout=dropout,
138
+ # channel_mult=channel_mult,
139
+ # num_classes=(NUM_CLASSES if class_cond else None),
140
+ # use_checkpoint=use_checkpoint,
141
+ # num_heads=num_heads,
142
+ # num_heads_upsample=num_heads_upsample,
143
+ # use_scale_shift_norm=use_scale_shift_norm,
144
+ # )
145
+
146
+
147
+ def create_gaussian_diffusion(
148
+ *,
149
+ steps=1000,
150
+ learn_sigma=False,
151
+ noise_schedule="linear", # sqrt
152
+ use_kl=False,
153
+ predict_xstart=False, # True
154
+ rescale_timesteps=False, # True
155
+ rescale_learned_sigmas=False, # True
156
+ timestep_respacing="",
157
+ model_arch="conv-unet", # transformer
158
+ training_mode="emb", # e2e
159
+ ):
160
+ return SpacedDiffusion(
161
+ use_timesteps=space_timesteps(2000, [2000]),
162
+ betas=gd.get_named_beta_schedule("sqrt", 2000),
163
+ model_mean_type=(gd.ModelMeanType.START_X),
164
+ model_var_type=(
165
+ (gd.ModelVarType.FIXED_LARGE)
166
+ if not learn_sigma
167
+ else gd.ModelVarType.LEARNED_RANGE
168
+ ),
169
+ loss_type=gd.LossType.E2E_MSE,
170
+ rescale_timesteps=True,
171
+ model_arch="transformer",
172
+ training_mode="e2e",
173
+ )
174
+
175
+
176
+ def add_dict_to_argparser(parser, default_dict):
177
+ for k, v in default_dict.items():
178
+ v_type = type(v)
179
+ if v is None:
180
+ v_type = str
181
+ elif isinstance(v, bool):
182
+ v_type = str2bool
183
+ parser.add_argument(f"--{k}", default=v, type=v_type)
184
+
185
+
186
+ def args_to_dict(args, keys):
187
+ return {k: getattr(args, k) for k in keys}
188
+
189
+
190
+ def str2bool(v):
191
+ """
192
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
193
+ """
194
+ if isinstance(v, bool):
195
+ return v
196
+ if v.lower() in ("yes", "true", "t", "y", "1"):
197
+ return True
198
+ elif v.lower() in ("no", "false", "f", "n", "0"):
199
+ return False
200
+ else:
201
+ raise argparse.ArgumentTypeError("boolean value expected")
src/improved_diffusion/test_util.py ADDED
@@ -0,0 +1,108 @@
1
+ import torch as th
2
+ import numpy as np
3
+
4
+ def compute_logp(args, model, x, input_ids):
5
+ word_emb = model.weight
6
+ sigma = 0.1
7
+ if args.model_arch == '1d-unet':
8
+ x = x.permute(0, 2, 1)
9
+
10
+ bsz, seqlen, dim = x.shape
11
+
12
+ x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) # 1, bsz*sample*seqlen, dim
13
+ word_emb_flat = word_emb.unsqueeze(1) # vocab, 1, dim
14
+ diff = (x_flat - word_emb_flat) ** 2 # vocab, bsz*seqlen, dim (via broadcasting)
15
+
16
+ logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) # vocab, bsz*seqlen
17
+ logp_expanded = logp_expanded.permute((1, 0))
18
+ # print(th.topk(logp_expanded.view(bsz, seqlen, -1), k=5, dim=-1)[0])
19
+ # print(input_ids[0])
20
+ ce = th.nn.CrossEntropyLoss(reduction='none')
21
+ loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen)
22
+ # print(loss[0])
23
+
24
+ # print(loss.shape)
25
+ return loss
26
+
27
+ def get_weights(model, args):
28
+ if hasattr(model, 'transformer'):
29
+ input_embs = model.transformer.wte # input_embs
30
+ down_proj = model.down_proj
31
+ down_proj_emb = down_proj(input_embs.weight)
32
+ print(down_proj_emb.shape)
33
+ # model = th.nn.Embedding(down_proj_emb.shape[1], down_proj_emb.shape[0])
34
+ model = th.nn.Embedding(down_proj_emb.size(0), down_proj_emb.size(1))
35
+ print(args.emb_scale_factor)
36
+ model.weight.data = down_proj_emb * args.emb_scale_factor
37
+
38
+ elif hasattr(model, 'weight'):
39
+ pass
40
+ else:
41
+ assert NotImplementedError
42
+
43
+ model.weight.requires_grad = False
44
+ return model
45
+
46
+ def denoised_fn_round(args, model, text_emb, t):
47
+ # return text_emb
48
+ thresh_t = 350
49
+ # print(thresh_t)
50
+ # print(t)
51
+ if thresh_t is not None and t[0] > thresh_t:
52
+ return text_emb
53
+ # return text_emb
54
+ # print(t.float().mean(), t[0])
55
+
56
+ # assert t.float().mean() == t[0].float()
57
+
58
+ # print(text_emb.shape) # bsz, seqlen, dim
59
+ # down_proj_emb = model.weight # input_embs
60
+ down_proj_emb = model
61
+ # print(t)
62
+ old_shape = text_emb.shape
63
+ old_device = text_emb.device
64
+
65
+ def get_efficient_knn(down_proj_emb, text_emb, dist='l2'):
66
+ if dist == 'l2':
67
+ emb_norm = (down_proj_emb**2).sum(-1).view(-1, 1) #vocab
68
+ text_emb_t = th.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) #d, bsz*seqlen
69
+ arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) #bsz*seqlen, 1
70
+ # print(emb_norm.shape, arr_norm.shape)
71
+ dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * th.mm(down_proj_emb, text_emb_t) #(vocab, d) x (d, bsz*seqlen)
72
+ dist = th.clamp(dist, 0.0, np.inf)
73
+ # print(dist.shape)
74
+ topk_out = th.topk(-dist, k=1, dim=0)
75
+ # adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
76
+ # down_proj_emb.size(0), -1, -1)
77
+ # adjacency = -th.norm(adjacency, dim=-1)
78
+ # topk_out = th.topk(adjacency, k=1, dim=0)
79
+ # print(topk_out1.indices == topk_out.indices)
80
+ # assert th.all(topk_out1.indices == topk_out.indices)
81
+ return topk_out.values, topk_out.indices
82
+
83
+ # def get_knn(down_proj_emb, text_emb, dist='l2'):
84
+ # if dist == 'l2':
85
+ # adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
86
+ # down_proj_emb.size(0), -1, -1)
87
+ # adjacency = -th.norm(adjacency, dim=-1)
88
+ # topk_out = th.topk(adjacency, k=1, dim=0)
89
+ # return topk_out.values, topk_out.indices
90
+
91
+ dist = 'l2'
92
+ if len(text_emb.shape) > 2:
93
+ text_emb = text_emb.reshape(-1, text_emb.size(-1))
94
+ else:
95
+ text_emb = text_emb
96
+ # val, indices = get_knn(down_proj_emb,
97
+ # text_emb.to(down_proj_emb.device), dist=dist)
98
+ val, indices = get_efficient_knn(down_proj_emb,
99
+ text_emb.to(down_proj_emb.device), dist=dist)
100
+ rounded_tokens = indices[0]
101
+ # print(rounded_tokens.shape)
102
+ new_embeds = model[rounded_tokens].view(old_shape).to(old_device)
103
+ return new_embeds
104
+
105
+ def load_results(json_path, load_dict):  # NOTE: despite the name, this writes load_dict to json_path
106
+ import json
107
+ with open(json_path, 'w') as f:
108
+ json.dump(load_dict, f, indent=2)
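get_efficient_knn above avoids materializing all pairwise differences by expanding ||e - x||^2 = ||e||^2 + ||x||^2 - 2 e.x into a single matrix multiply. A quick check of that identity (a standalone sketch, independent of the code above):

# Sanity-check sketch of the squared-distance expansion.
import torch as th

emb = th.randn(50, 8)   # vocab x dim
x = th.randn(7, 8)      # tokens x dim

d_fast = (emb ** 2).sum(-1, keepdim=True) + (x ** 2).sum(-1) - 2.0 * emb @ x.T
d_naive = ((emb[:, None, :] - x[None, :, :]) ** 2).sum(-1)  # all-pairs form
assert th.allclose(d_fast, d_naive, atol=1e-4)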
src/improved_diffusion/text_datasets.py ADDED
@@ -0,0 +1,948 @@
1
+ # from PIL import Image
2
+ # import blobfile as bf
3
+ from mpi4py import MPI
4
+ import numpy as np
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, PreTrainedTokenizerFast, \
7
+ PreTrainedTokenizer
8
+ # from datasets import load_dataset
9
+ import sys, os
10
+ import torch
11
+ # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
12
+ # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
13
+ from collections import Counter, defaultdict
14
+ from functools import partial
15
+ from itertools import chain
16
+
17
+
18
+ def load_data_text(
19
+ *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None,
20
+ task_mode='roc', model=None, padding_mode='block', split='train', load_vocab=None,
21
+ ):
22
+ """
23
+ For a dataset, create a generator over (batch, kwargs) pairs.
24
+
25
+ Each batch is a float tensor, and the kwargs dict contains zero or
26
+ more keys, each of which maps to a batched Tensor of its own.
27
+ The kwargs dict can be used for class labels, in which case the key is "y"
28
+ and the values are integer tensors of class labels.
29
+
30
+ :param data_dir: a dataset directory.
31
+ :param batch_size: the batch size of each returned pair.
32
+ :param image_size: the size to which images are resized.
33
+ :param class_cond: if True, include a "y" key in returned dicts for class
34
+ label. If classes are not available and this is true, an
35
+ exception will be raised.
36
+ :param deterministic: if True, yield results in a deterministic order.
37
+ """
38
+ print('hello loading text data. ')
39
+
40
+ if data_args.experiment.startswith('random') and model is None:
41
+ model = None
42
+ # elif data_args.experiment.startswith('random') and model is not None:
43
+ # print('loading initialized random embeddings. ')
44
+
45
+ if task_mode == 'roc' or task_mode == 'roc-aug' :
46
+ pass
47
+ # training_data, model = get_corpus_rocstory(data_args, model, image_size,
48
+ # padding_mode=padding_mode, split=split,
49
+ # load_vocab=load_vocab)
50
+ elif task_mode == 'simple-wiki':
51
+ pass
52
+ # training_data, model = get_corpus_rocstory(data_args, model, image_size,
53
+ # padding_mode=padding_mode, split=split,
54
+ # load_vocab=load_vocab)
55
+
56
+ elif task_mode == 'e2e-tgt':
57
+ print('hello loading e2e-tgt. ')
58
+ training_data, model = get_corpus_rocstory(data_args, model, image_size,
59
+ padding_mode=padding_mode, split=split,
60
+ load_vocab=load_vocab)
61
+ # elif task_mode == 'yelp':
62
+ # print('hello loading yelp ')
63
+ # training_data, model = get_corpus_rocstory(data_args, model, image_size,
64
+ # padding_mode=padding_mode, split=split,
65
+ # load_vocab=load_vocab)
66
+
67
+ # elif task_mode == 'commonGen' or task_mode == 'commonGen-aug':
68
+ # print('hello loading common-gen ')
69
+ # training_data, model = get_corpus_rocstory(data_args, model, image_size,
70
+ # padding_mode=padding_mode, split=split,
71
+ # load_vocab=load_vocab)
72
+
73
+ # elif task_mode == 'e2e':
74
+ # training_data, model = get_corpus_rocstory(data_args, model, image_size,
75
+ # padding_mode=padding_mode, split=split,
76
+ # load_vocab=load_vocab)
77
+
78
+ # elif task_mode == 'book':
79
+ # tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
80
+ # training_data, model = get_corpus_book(data_args, tokenizer, model, image_size,
81
+ # padding_mode=padding_mode, split=split,)
82
+
83
+ if data_args.modality in ['roc-aug', 'roc', 'book', 'yelp', 'commonGen', 'commonGen-aug'] and data_args.cache_mode=='no':
84
+ pass  # dataset = TextDataset_NoCache(
85
+ # training_data,
86
+ # image_size,
87
+ # data_args,
88
+ # model_arch=data_args.model_arch,
89
+ # model_emb=model
90
+ # )
91
+ else:
92
+ dataset = TextDataset(
93
+ training_data,
94
+ image_size,
95
+ data_args,
96
+ model_arch=data_args.model_arch,
97
+ )
98
+
99
+ if deterministic:
100
+
101
+ pass  # NOTE: branch disabled; deterministic=True leaves data_loader undefined below  # data_loader = DataLoader(
102
+ # dataset,
103
+ # batch_size=batch_size, # 20,
104
+ # drop_last=True,
105
+ # shuffle=False,
106
+ # num_workers=1,
107
+ # )
108
+
109
+ else:
110
+ data_loader = DataLoader(
111
+ dataset,
112
+ batch_size=batch_size, # 20,
113
+ drop_last=True,
114
+ shuffle=True,
115
+ num_workers=1,
116
+ )
117
+ while True:
118
+ yield from data_loader
119
+
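Because of the `while True: yield from data_loader` above, load_data_text returns an infinite generator; training loops pull batches with next() instead of iterating epochs. A hedged sketch (the argument values, the data_args object, and the (batch, cond) unpacking are assumptions based on the docstring):

# Illustrative consumption of the infinite generator (values are assumptions).
data = load_data_text(
    data_dir="datasets/e2e", batch_size=16, image_size=64,
    data_args=data_args, task_mode="e2e-tgt", split="train",
)
for step in range(1000):
    batch, cond = next(data)  # never raises StopIteration; wraps around forever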
120
+ def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
121
+ result_train_lst = []
122
+ group_lst = defaultdict(list)
123
+ with torch.no_grad():
124
+ for (src_ids, input_ids) in sentence_lst:
125
+ tokenized_ = [vocab_dict.get(x, vocab_dict['UNK']) for x in input_ids]
126
+ tokenized_src = [vocab_dict.get(x, vocab_dict['UNK']) for x in src_ids]
127
+ input_ids = [0] + tokenized_ + [1]
128
+ group_lst['word_ids'].append(input_ids)
129
+ group_lst['src_ids'].append(tokenized_src)
130
+
131
+ print(group_lst['word_ids'][:2])
132
+ print('padding mode is pad')
133
+ max_length = seqlen
134
+ group_lst['word_ids'] = _collate_batch_helper(group_lst['word_ids'], vocab_dict['PAD'], max_length)
135
+ max_src_length = max([len(xx) for xx in group_lst['src_ids']])
136
+ print(max_src_length, seqlen)
137
+ max_src_length = min(seqlen, max_src_length)
138
+ group_lst['src_ids'], group_lst['src_mask'] = _collate_batch_helper(group_lst['src_ids'],
139
+ vocab_dict['PAD'],
140
+ max_src_length,
141
+ return_mask=True)
142
+
143
+
144
+ for input_ids, src_ids, src_mask in zip(group_lst['word_ids'], group_lst['src_ids'],
145
+ group_lst['src_mask']):
146
+ if data_args.experiment.startswith('random'):
147
+ hidden_state = model(torch.tensor(input_ids))
148
+ elif data_args.experiment == 'gpt2_pre_compress':
149
+ input_ids2 = torch.tensor(input_ids).to(model.device)
150
+ input_embs = model.transformer.wte(input_ids2) # input_embs
151
+ hidden_state = model.down_proj(input_embs)
152
+ hidden_state = hidden_state * data_args.emb_scale_factor
153
+ result_train_lst.append({'input_ids': input_ids,
154
+ 'hidden_states': hidden_state.cpu().tolist(),
155
+ 'src_ids':src_ids,
156
+ 'src_mask':src_mask
157
+ })
158
+
159
+ return result_train_lst
160
+
161
+ def helper_tokenize_stream(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ):
162
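+ # Tokenize through the HuggingFace datasets library (memory-mapped Arrow tables) to keep RAM bounded; usage is printed at each stage.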
+ import psutil
163
+ # Process.memory_info is expressed in bytes, so convert to megabytes
164
+ print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
165
+ from datasets import Dataset as Dataset2
166
+ raw_datasets = Dataset2.from_dict({'text':sentence_lst})
167
+ print(raw_datasets)
168
+ print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
169
+
170
+
171
+ def tokenize_function(examples):
172
+ if isinstance(vocab_dict, dict):
173
+ input_ids = [[0] + [vocab_dict.get(x, vocab_dict['UNK']) for x in seq] + [1] for seq in examples['text']]
174
+ elif isinstance(vocab_dict, PreTrainedTokenizerFast):
175
+ examples['text'] = [" ".join(seq) for seq in examples['text']]
176
+ input_ids = vocab_dict(examples['text'], add_special_tokens=True)['input_ids']
177
+ result_dict = {'input_ids': input_ids}
178
+ # clm input could be much much longer than block_size
179
+ return result_dict
180
+
181
+ tokenized_datasets = raw_datasets.map(
182
+ tokenize_function,
183
+ batched=True,
184
+ num_proc=4,
185
+ remove_columns=['text'],
186
+ load_from_cache_file=True,
187
+ desc="Running tokenizer on dataset",
188
+ )
189
+ print(tokenized_datasets)
190
+ print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
191
+
192
+ if padding_mode == 'block':
193
+ block_size = seqlen
194
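+ # Concatenate all tokenized sequences and split them into fixed-size blocks, dropping the ragged tail.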
+ def group_texts(examples):
195
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
196
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
197
+ if total_length >= block_size:
198
+ total_length = (total_length // block_size) * block_size
199
+ result = {
200
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
201
+ for k, t in concatenated_examples.items()
202
+ }
203
+ result["labels"] = result["input_ids"].copy()
204
+ return result
205
+
206
+
207
+ lm_datasets = tokenized_datasets.map(
208
+ group_texts,
209
+ batched=True,
210
+ num_proc=data_args.preprocessing_num_workers,
211
+ load_from_cache_file=not data_args.overwrite_cache,
212
+ desc=f"Grouping texts in chunks of {block_size}",
213
+ )
214
+ else:
215
+ def pad_function(group_lst):
216
+ max_length = seqlen
217
+ if isinstance(vocab_dict, dict):
218
+ group_lst['input_ids'] = _collate_batch_helper(group_lst['input_ids'], vocab_dict['PAD'], max_length)
219
+ else:
220
+ group_lst['input_ids'] = _collate_batch_helper(group_lst['input_ids'], vocab_dict.pad_token_id, max_length)
221
+ return group_lst
222
+
223
+ # Process.memory_info is expressed in bytes, so convert to megabytes
224
+ print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
225
+
226
+ lm_datasets = tokenized_datasets.map(
227
+ pad_function,
228
+ batched=True,
229
+ num_proc=1,
230
+ desc=f"padding",
231
+ )
232
+
233
+
234
+ print(lm_datasets, 'padded dataset')
235
+ print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
236
+ import datasets
237
+ raw_datasets = datasets.DatasetDict()
238
+ raw_datasets['train'] = lm_datasets
239
+ print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
240
+ return raw_datasets
241
+
242
+ def helper_tokenize_encode(sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ):
243
+ result_train_lst = []
244
+ group_lst = defaultdict(list)
245
+ with torch.no_grad():
246
+ for input_ids in sentence_lst:
247
+ tokenized_ = [vocab_dict.get(x, vocab_dict['UNK']) for x in input_ids]
248
+ input_ids = [0] + tokenized_ + [1]
249
+ group_lst['word_ids'].append(input_ids)
250
+ print(group_lst['word_ids'][:2])
251
+
252
+ if padding_mode == 'block':
253
+ print('padding mode is block')
254
+ concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
255
+ total_length = len(concatenated_examples[list(group_lst.keys())[0]])
256
+ block_size = seqlen
257
+ total_length = (total_length // block_size) * block_size
258
+ # Split by chunks of max_len.
259
+ group_lst = {
260
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
261
+ for k, t in concatenated_examples.items()
262
+ }
263
+ elif padding_mode == 'pad':
264
+ print('padding mode is pad')
265
+ max_length = seqlen
266
+ group_lst['word_ids'] = _collate_batch_helper(group_lst['word_ids'], vocab_dict['PAD'], max_length)
267
+
268
+ for input_ids in group_lst['word_ids']:
269
+ if data_args.experiment.startswith('random'):
270
+ hidden_state = model(torch.tensor(input_ids))
271
+ elif data_args.experiment == 'gpt2_pre_compress':
272
+ input_ids2 = torch.tensor(input_ids).to(model.device)
273
+ input_embs = model.transformer.wte(input_ids2) # input_embs
274
+ hidden_state = model.down_proj(input_embs)
275
+ hidden_state = hidden_state * data_args.emb_scale_factor
276
+ elif data_args.experiment == 'glove':
277
+ hidden_state = model(torch.tensor(input_ids))
278
+ result_train_lst.append({'input_ids': input_ids, 'hidden_states': hidden_state.cpu().tolist()})
279
+
280
+ return result_train_lst
281
+
282
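+ # Parse a GloVe text file (one word plus its vector per line) into a dict of word -> embedding tensor.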
+ def load_glove_model(File):
283
+ print("Loading Glove Model")
284
+ glove_model = {}
285
+ with open(File,'r') as f:
286
+ for line in f:
287
+ split_line = line.split()
288
+ word = split_line[0]
289
+ embedding = torch.tensor(np.array(split_line[1:], dtype=np.float64))
290
+ # embedding = np.array(split_line[1:], dtype=np.float64)
291
+ glove_model[word] = embedding
292
+ print(f"{len(glove_model)} words loaded!")
293
+ return glove_model
294
+
295
+ def load_glove(vocab):
296
+ model = torch.nn.Embedding(len(vocab), 50)
297
+ glove_model = load_glove_model('predictability/glove/glove.6B.50d.txt')
298
+ array_lst = []
299
+ count_ = 0
300
+ for word, idx in vocab.items():
301
+ if word in glove_model:
302
+ array_lst.append(glove_model[word])
303
+ else:
304
+ count_ += 1
305
+ array_lst.append(torch.randn(50))
306
+ print(f'{count_} out of {len(vocab)} words were not found in GloVe and were randomly initialized.')
307
+ array_lst = torch.stack(array_lst)
308
+ print(torch.norm(array_lst, dim=-1).mean())
309
+ model.weight.data = array_lst
310
+ return model
311
+
312
+
313
+ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
314
+ split='train', load_vocab=None):
315
+ import csv, torch, json
316
+ from spacy.lang.en import English
317
+
318
+ if data_args.experiment_mode == 'lm':
319
+ if data_args.modality == 'roc':
320
+ pass
321
+ # print('loading dataset from ROCStory')
322
+ # nlp = English()
323
+ # tokenizer = nlp.tokenizer
324
+ # sentence_lst = []
325
+ # print(f'loading from {data_args.roc_train}')
326
+ # if split == 'train':
327
+ # print('loading from the TRAIN set')
328
+ # path = f'{data_args.roc_train}/roc_train.json'
329
+ # elif split == 'valid':
330
+ # print('loading from the VALID set')
331
+ # path = f'{data_args.roc_train}/roc_valid.json'
332
+ # else:
333
+ # assert False, "invalid split for ROC dataset"
334
+
335
+ # with open(path, 'r') as roc_reader:
336
+ # for row in roc_reader:
337
+ # sentences = json.loads(row)[0].strip()
338
+ # word_lst = [x.text for x in tokenizer(sentences)]
339
+ # sentence_lst.append(word_lst)
340
+
341
+ # # with open(data_args.roc_train, 'r') as csvfile:
342
+ # # roc_reader = csv.reader(csvfile) #delimiter=' ', quotechar='|')
343
+ # # for row in roc_reader:
344
+ # # # tokenize.
345
+ # # sentences = " ".join(row[2:])
346
+ # # word_lst = [x.text for x in tokenizer(sentences)]
347
+ # # sentence_lst.append(word_lst)
348
+ # # sentence_lst = sentence_lst[1:]
349
+ # print(sentence_lst[:2])
350
+ if data_args.modality == 'roc-aug':
351
+ pass
352
+ # print('loading dataset from ROCStory')
353
+ # nlp = English()
354
+ # tokenizer = nlp.tokenizer
355
+ # sentence_lst = []
356
+ # if split == 'train':
357
+ # print('loading from the TRAIN set')
358
+ # path_lst = [f'{data_args.roc_train}/roc_train.json']
359
+ # path_lst.append('diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt')
360
+ # # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc.json')
361
+ # # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc2.json')
362
+
363
+ # elif split == 'valid':
364
+ # print('loading from the VALID set')
365
+ # path_lst = [f'{data_args.roc_train}/roc_valid.json']
366
+ # else:
367
+ # assert False, "invalid split for ROC dataset"
368
+
369
+ # print(path_lst)
370
+ # for path in path_lst:
371
+ # if path.endswith('txt'):
372
+ # with open(path, 'r') as roc_reader:
373
+ # for row in roc_reader:
374
+ # sentences = row.strip()
375
+ # word_lst = [x.text for x in tokenizer(sentences)]
376
+ # sentence_lst.append(word_lst)
377
+ # else:
378
+ # with open(path, 'r') as roc_reader:
379
+ # for row in roc_reader:
380
+ # sentences = json.loads(row)[0].strip()
381
+ # word_lst = [x.text for x in tokenizer(sentences)]
382
+ # sentence_lst.append(word_lst)
383
+ # print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst))
384
+ elif data_args.modality == 'simple-wiki':
385
+ pass
386
+ # print('loading dataset from simple wikipedia')
387
+ # sentence_lst = []
388
+ # with open(data_args.wiki_train, 'r') as ff:
389
+ # for row in ff:
390
+ # word_lst = row.lower().split()
391
+ # sentence_lst.append(word_lst)
392
+ # print(sentence_lst[:2])
393
+ elif data_args.modality == 'e2e-tgt':
394
+ print('loading dataset from the E2E dataset')
395
+ sentence_lst = []
396
+ nlp = English()
397
+ tokenizer = nlp.tokenizer
398
+ if split == 'train':
399
+ print('loading from the TRAIN set')
400
+ path = '/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt'
401
+ # path = f'../{data_args.e2e_train}/src1_train.txt'
402
+ elif split == 'valid':
403
+ print('loading from the VALID set')
404
+ path = f'../{data_args.e2e_train}/src1_valid.txt'
405
+ path = '/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt'
406
+ elif split == 'test':
407
+ print('loading from the TEST set')
408
+ path = f'../{data_args.e2e_train}/src1_test.txt'
409
+ path = '/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt'
410
+ elif split == 'debug':
411
+ print('loading from the DEBUG set')
412
+ path = data_args.debug_path
413
+ import json
414
+ with open(path, 'r') as ff:
415
+ for line in ff:
416
+ sentence_lst.append(json.loads(line)[0].split(' '))
417
+ sentence_lst = sentence_lst + sentence_lst
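+ # The small debug set is duplicated, presumably so that full batches can still be drawn with drop_last=True.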
418
+ if split in ['train', 'valid', 'test']:
419
+ with open(path, 'r') as ff:
420
+ for row in ff:
421
+ word_lst = row.split('||')[1]
422
+ word_lst = [x.text for x in tokenizer(word_lst)]
423
+ sentence_lst.append(word_lst)
424
+ print(sentence_lst[:2])
425
+
426
+ elif data_args.modality == 'yelp':
427
+ print('loading dataset from the YelpNLG dataset')
428
+ sentence_lst = []
429
+ nlp = English()
430
+ tokenizer = nlp.tokenizer
431
+ if split == 'train':
432
+ print('loading from the TRAIN set')
433
+ path = f'{data_args.yelp_train}/yelpnlg-train.csv'
434
+ elif split == 'valid':
435
+ print('loading from the VALID set')
436
+ path = f'{data_args.yelp_train}/yelpnlg-dev.csv'
437
+ elif split == 'test':
438
+ print('loading from the TEST set')
439
+ path = f'{data_args.yelp_train}/yelpnlg-test.csv'
440
+ if split in ['train', 'valid', 'test']:
441
+
442
+ with open(path, 'r') as csvfile:
443
+ yelp_reader = csv.reader(csvfile) #delimiter=' ', quotechar='|')
444
+ for row in yelp_reader:
445
+ sentences = row[1]
446
+ word_lst = [x.text for x in tokenizer(sentences)]
447
+ sentence_lst.append(word_lst)
448
+ sentence_lst = sentence_lst[1:]
449
+ print(sentence_lst[:2])
450
+
451
+ elif data_args.modality == 'commonGen':
452
+ print('loading dataset from the CommonGen dataset')
453
+ sentence_lst = []
454
+ nlp = English()
455
+ tokenizer = nlp.tokenizer
456
+ if split == 'train':
457
+ print('loading from the TRAIN set')
458
+ path = f'{data_args.commonGen_train}/commongen.train.jsonl'
459
+ elif split == 'valid':
460
+ print('loading from the VALID set')
461
+ path = f'{data_args.commonGen_train}/commongen.dev.jsonl'
462
+ elif split == 'test':
463
+ print('loading from the TEST set')
464
+ path = f'{data_args.commonGen_train}/commongen.test.jsonl'
465
+ if split in ['train', 'valid', 'test']:
466
+ with open(path, 'r') as ff:
467
+ for line in ff:
468
+ line = json.loads(line)
469
+ for sentences in line['scene']:
470
+ word_lst = [x.text for x in tokenizer(sentences)]
471
+ sentence_lst.append(word_lst)
472
+ print(sentence_lst[:2])
473
+
474
+ elif data_args.modality == 'commonGen-aug':
475
+ print('loading dataset from the CommonGen dataset (augmented)')
476
+ sentence_lst = []
477
+ nlp = English()
478
+ tokenizer = nlp.tokenizer
479
+ if split == 'train':
480
+ print('loading from the TRAIN set')
481
+ path = f'{data_args.commonGen_train}/commongen.train.jsonl'
482
+ path_lst = [f'{data_args.roc_train}/roc_train.json']
483
+ path_lst.append('diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt')
484
+ elif split == 'valid':
485
+ print('loading from the VALID set')
486
+ path = f'{data_args.commonGen_train}/commongen.dev.jsonl'
487
+ path_lst = []
488
+ elif split == 'test':
489
+ print('loading from the TEST set')
490
+ path = f'{data_args.commonGen_train}/commongen.test.jsonl'
491
+ path_lst = []
492
+
493
+ if split in ['train', 'valid', 'test']:
494
+ with open(path, 'r') as ff:
495
+ for line in ff:
496
+ line = json.loads(line)
497
+ for sentences in line['scene']:
498
+ word_lst = [x.text for x in tokenizer(sentences)]
499
+ sentence_lst.append(word_lst)
500
+ print(sentence_lst[:2])
501
+ import itertools
502
+ for path in path_lst:
503
+ if path.endswith('txt'):
504
+ with open(path, 'r') as roc_reader:
505
+ for row in roc_reader:
506
+ sentences = row.strip()
507
+ word_lst = [x.text for x in tokenizer(sentences)]
508
+ spl = [[]]
509
+ for x, y in itertools.groupby(word_lst, lambda z: z == '.'):
510
+ spl[-1].extend(y)
511
+ if x: spl.append([])
512
+ sentence_lst.extend(spl[:-1])
513
+ else:
514
+ with open(path, 'r') as roc_reader:
515
+ for row in roc_reader:
516
+ sentences = json.loads(row)[0].strip()
517
+ word_lst = [x.text for x in tokenizer(sentences)]
518
+ spl = [[]]
519
+ for x, y in itertools.groupby(word_lst, lambda z: z == '.'):
520
+ spl[-1].extend(y)
521
+ if x: spl.append([])
522
+ sentence_lst.extend(spl[:-1])
523
+
524
+ print(sentence_lst[-2:])
525
+
526
+
527
+ # get tokenizer.
528
+ if load_vocab is None:
529
+ counter = Counter()
530
+ for input_ids in sentence_lst:
531
+ counter.update(input_ids)
532
+
533
+ if data_args.experiment_mode == 'conditional_gen':
534
+ if data_args.modality == 'e2e':
535
+ print('loading dataset from the E2E dataset')
536
+ sentence_lst = []
537
+ nlp = English()
538
+ tokenizer = nlp.tokenizer
539
+ if split == 'train':
540
+ path = f'{data_args.e2e_train}/src1_train.txt'
541
+ with open(path, 'r') as ff:
542
+ for row in ff:
543
+ src_lst, word_lst = row.split('||')
544
+ word_lst = [x.text for x in tokenizer(word_lst)]
545
+ src_lst = [x.text for x in tokenizer(src_lst)]
546
+ sentence_lst.append((src_lst, word_lst))
547
+ elif split == 'valid':
548
+ path = f'{data_args.e2e_train}/src1_valid.txt'
549
+ sentence_lst = read_e2e_files(path, data_args, tokenizer)
550
+ print(sentence_lst[:2])
551
+ # get tokenizer.
552
+ if load_vocab is None:
553
+ counter = Counter()
554
+ for (src_ids, input_ids) in sentence_lst:
555
+ counter.update(input_ids)
556
+ counter.update(src_ids)
557
+
558
+ if load_vocab is None:
559
+ vocab_dict = {'START': 0, 'END': 1, 'UNK':2, 'PAD':3}
560
+ for k, v in counter.items():
561
+ if v > 10:
562
+ vocab_dict[k] = len(vocab_dict)
563
+ print(len(counter), len(vocab_dict))
564
+
565
+ path_save_vocab = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json'
566
+ print(f'save the vocab to {path_save_vocab}')
567
+ with open(path_save_vocab, 'w') as f:
568
+ json.dump(vocab_dict, f)
569
+ else:
570
+ vocab_dict = load_vocab
571
+ path_save_vocab = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json'
572
+ if not os.path.exists(path_save_vocab):
573
+ print(f'save the vocab to {path_save_vocab}')
574
+ if isinstance(vocab_dict, dict):
575
+ with open(path_save_vocab, 'w') as f:
576
+ json.dump(vocab_dict, f)
577
+ assert vocab_dict['START'] == 0
578
+ elif isinstance(vocab_dict, PreTrainedTokenizerFast):
579
+ vocab_dict.save_pretrained(data_args.checkpoint_path)
580
+ else:
581
+ assert False, "invalid type of vocab_dict"
582
+
583
+
584
+
585
+ if model is None and data_args.experiment == 'random':
586
+ model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
587
+ print('initializing the random embeddings', model)
588
+ torch.nn.init.normal_(model.weight)
589
+ path_save = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch'
590
+ print(f'save the random encoder to {path_save}')
591
+ torch.save(model.state_dict(), path_save)
592
+
593
+ # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
594
+ # if not os.path.exists(path_save) and data_args.experiment == 'random':
595
+ # torch.save(model.state_dict(), path_save)
596
+
597
+
598
+ if data_args.experiment_mode == 'lm' and data_args.modality in ['roc-aug', 'roc', 'yelp', 'commonGen', 'commonGen-aug'] \
599
+ and data_args.cache_mode=='no':
600
+ train_dataset = helper_tokenize_stream(sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode)
601
+ return train_dataset, model
602
+ elif data_args.experiment_mode == 'lm':
603
+ result_train_lst = helper_tokenize_encode(sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode)
604
+ elif data_args.experiment_mode == 'conditional_gen':
605
+ result_train_lst = helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, image_size ** 2, data_args)
606
+ return {'train': result_train_lst}, model
607
+
608
+
609
+ def write_e2e_corr(prompt_lst, file_dict, corr_path):
610
+ print(len(prompt_lst))
611
+ with open(corr_path, 'w') as f:
612
+ for x in prompt_lst:
613
+ for line in file_dict[x]:
614
+ print(" ".join(line), file=f)
615
+ print('', file=f)
616
+
617
+
618
+ def write_e2e_src(prompt_lst, corr_path):
619
+ with open(corr_path, 'w') as f:
620
+ for x in prompt_lst:
621
+ print(" ".join(x), file=f)
622
+ return
623
+
624
+
625
+ def read_e2e_files(path, args, tokenizer):
626
+ file_dict = {}
627
+ with open(path, 'r') as f:
628
+ for line in f:
629
+ src_lst, word_lst = line.strip().split('||')
630
+ tgt = tuple([x.text for x in tokenizer(word_lst)])
631
+ src = tuple([x.text for x in tokenizer(src_lst)])
632
+ if src not in file_dict:
633
+ file_dict[src] = []
634
+ file_dict[src].append(tgt)
635
+ temp = '1'
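+ # '1' is a fixed tag used only in the gold/src output filenames below.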
636
+ prompt_text_dict = file_dict
637
+ prompt_text_lst = list(prompt_text_dict.keys())
638
+ gold_dir = os.path.join(args.out_dir, '{}_{}_{}'.format(temp, args.split, 'gold'))
639
+ print("gold dir", gold_dir)
640
+ write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
641
+ src_dir = os.path.join(args.out_dir, '{}_{}_{}'.format(temp, args.split, 'src'))
642
+ write_e2e_src(prompt_text_lst, src_dir)
643
+ final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst]
644
+ return final_lst
645
+
646
+
647
+ def get_corpus_book(data_args, tokenizer, model, image_size, padding_mode='block', split='train',):
648
+ max_length = image_size ** 2
649
+ import os
650
+ assert padding_mode == 'block'
651
+ raw_datasets = load_dataset('bookcorpus')
652
+ if "validation" not in raw_datasets.keys():
653
+ raw_datasets["validation"] = load_dataset(
654
+ 'bookcorpus',
655
+ split=f"train[:1%]",
656
+ )
657
+ raw_datasets["train"] = load_dataset(
658
+ 'bookcorpus',
659
+ split=f"train[1%:]",
660
+ )
661
+ print(raw_datasets)
662
+ column_names = raw_datasets["train"].column_names
663
+
664
+ def tokenize_function(examples):
665
+ output = tokenizer(examples['text'], add_special_tokens=False)
666
+ return output
667
+
668
+
669
+ tokenized_datasets = raw_datasets.map(
670
+ tokenize_function,
671
+ batched=True,
672
+ num_proc=data_args.preprocessing_num_workers,
673
+ remove_columns=column_names,
674
+ load_from_cache_file=True,
675
+ )
676
+
677
+ print(tokenized_datasets)
678
+
679
+ block_size = max_length
680
+
681
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
682
+ def group_texts(examples):
683
+ # Concatenate all texts.
684
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
685
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
686
+ if total_length >= block_size:
687
+ total_length = (total_length // block_size) * block_size
688
+ result = {
689
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
690
+ for k, t in concatenated_examples.items()
691
+ }
692
+ return result
693
+
694
+ lm_datasets = tokenized_datasets.map(
695
+ group_texts,
696
+ batched=True,
697
+ num_proc=4,
698
+ load_from_cache_file=True,
699
+ desc=f"Grouping texts in chunks of {block_size}",
700
+ )
701
+
702
+ print(lm_datasets)
703
+
704
+ if model is None:
705
+ if data_args.training_mode.startswith('e2e'):
706
+ print("since it's e2e, initialize a dummy embedding")
707
+ model = torch.nn.Embedding(len(tokenizer), 1)
708
+ else:
709
+ model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
710
+ print('initializing the random embeddings', model)
711
+ torch.nn.init.normal_(model.weight)
712
+ path_save = f'{data_args.checkpoint_path}/random_emb.torch'
713
+ print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
714
+ torch.save(model.state_dict(), path_save)
715
+
716
+ if split == 'train':
717
+ return lm_datasets, model
718
+ else:
719
+ lm_datasets['train'] = lm_datasets['validation']
720
+ return lm_datasets, model
721
+
722
+
723
+ class TextDataset(Dataset):
724
+ def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
725
+ classes=None, shard=0, num_shards=1, eigen_transform=None,
726
+ mapping_func=None, model_emb=None):
727
+ super().__init__()
728
+ self.resolution = resolution
729
+ self.text_datasets = text_datasets
730
+ self.length = len(self.text_datasets['train'])
731
+ self.model_arch = model_arch
732
+ self.data_args = data_args
733
+ print(self.resolution)
734
+ self.eigen_transform = eigen_transform
735
+ self.mapping_func = mapping_func
736
+ self.model_emb = model_emb
737
+ # self.local_images = image_paths[shard:][::num_shards]
738
+ # self.local_classes = None if classes is None else classes[shard:][::num_shards]
739
+
740
+ def __len__(self):
741
+ return self.length
742
+
743
+ def __getitem__(self, idx):
744
+
745
+ # Note: the PIL/downsampling comment from the original image-diffusion code
746
+ # does not apply here; this dataset simply returns precomputed per-token
747
+ # hidden states for the chosen model architecture.
748
+ if self.model_arch == 'conv-unet':
749
+ pass# arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
750
+ # dtype=np.float32).reshape(self.resolution, self.resolution, -1)
751
+ # # print(self.eigen_transform.shape)
752
+ # if self.eigen_transform is not None:
753
+ # old_shape = arr.shape
754
+ # arr = arr.reshape(1, -1) - self.eigen_transform['mean']
755
+ # arr = arr @ self.eigen_transform['map']
756
+ # arr = arr.reshape(old_shape)
757
+ # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
758
+ # arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
759
+
760
+
761
+ # out_dict = {}
762
+ # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
763
+ # # if self.local_classes is not None:
764
+ # # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
765
+ # # print(out_dict.keys())
766
+ # return np.transpose(arr, [2, 0, 1]), out_dict
767
+ elif self.model_arch == '1d-unet':
768
+ pass# arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
769
+ # dtype=np.float32) # seqlen, dim
770
+ # if self.eigen_transform is not None:
771
+ # old_shape = arr.shape
772
+ # arr = arr.reshape(1, -1) - self.eigen_transform['mean']
773
+ # arr = arr @ self.eigen_transform['map']
774
+ # arr = arr.reshape(old_shape)
775
+ # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
776
+ # arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
777
+ # arr = np.transpose(arr, [1, 0])
778
+ # out_dict = {}
779
+ # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
780
+ # # out_dict['mapping_func'] = self.mapping_func
781
+ # # if self.local_classes is not None:
782
+ # # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
783
+ # # print(arr.shape)
784
+ # return arr, out_dict
785
+ else:
786
+ arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
787
+ dtype=np.float32)
788
+ if self.eigen_transform is not None:
789
+ old_shape = arr.shape
790
+ # arr = arr.reshape(1, -1) @ self.eigen_transform
791
+ arr = arr.reshape(1, -1) - self.eigen_transform['mean']
792
+ arr = arr @ self.eigen_transform['map']
793
+ arr = arr.reshape(old_shape)
794
+
795
+ if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
796
+ # print(arr.dtype)
797
+ # print(self.data_args.noise_level, 'using the noise level.')
798
+ arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
799
+ # print(arr.dtype)
800
+
801
+ out_dict = {}
802
+ out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
803
+ # out_dict['mapping_func'] = self.mapping_func
804
+ if self.data_args.experiment_mode == 'conditional_gen':
805
+ out_dict['src_ids'] = np.array(self.text_datasets['train'][idx]['src_ids'])
806
+ out_dict['src_mask'] = np.array(self.text_datasets['train'][idx]['src_mask'])
807
+ # if self.local_classes is not None:
808
+ # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
809
+ return arr, out_dict
810
+ # print(arr.dtype)
811
+ # arr = arr.float()
812
+ # print(arr.shape)
813
+
814
+
815
+ class TextDataset_NoCache(Dataset):
816
+ def __init__(self, text_datasets, resolution, data_args, model_arch='conv-unet',
817
+ classes=None, shard=0, num_shards=1, eigen_transform=None,
818
+ mapping_func=None, model_emb=None):
819
+ super().__init__()
820
+ self.resolution = resolution
821
+ self.text_datasets = text_datasets
822
+ self.length = len(self.text_datasets['train'])
823
+ self.model_arch = model_arch
824
+ self.data_args = data_args
825
+ print(self.resolution)
826
+ self.eigen_transform = eigen_transform
827
+ self.mapping_func = mapping_func
828
+ self.model_emb = model_emb
829
+ # self.local_images = image_paths[shard:][::num_shards]
830
+ # self.local_classes = None if classes is None else classes[shard:][::num_shards]
831
+
832
+ def __len__(self):
833
+ return self.length
834
+
835
+ def __getitem__(self, idx):
836
+
837
+ # No-cache variant of TextDataset: instead of storing precomputed hidden
838
+ # states, token embeddings are recomputed from input_ids each time an item
839
+ # is fetched.
840
+ with torch.no_grad():
841
+ input_ids = self.text_datasets['train'][idx]['input_ids']
842
+ model = self.model_emb
843
+ if self.data_args.experiment.startswith('random'):
844
+ hidden_state = model(torch.tensor(input_ids))
845
+ elif self.data_args.experiment == 'gpt2_pre_compress':
846
+ input_ids2 = torch.tensor(input_ids).to(model.device)
847
+ input_embs = model.transformer.wte(input_ids2) # input_embs
848
+ hidden_state = model.down_proj(input_embs)
849
+ hidden_state = hidden_state * self.data_args.emb_scale_factor
850
+
851
+ if self.model_arch == 'conv-unet':
852
+ arr = np.array(hidden_state,
853
+ dtype=np.float32).reshape(self.resolution, self.resolution, -1)
854
+ # print(self.eigen_transform.shape)
855
+ if self.eigen_transform is not None:
856
+ old_shape = arr.shape
857
+ arr = arr.reshape(1, -1) - self.eigen_transform['mean']
858
+ arr = arr @ self.eigen_transform['map']
859
+ arr = arr.reshape(old_shape)
860
+ if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
861
+ arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
862
+
863
+ out_dict = {}
864
+ out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
865
+ # if self.local_classes is not None:
866
+ # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
867
+ # print(out_dict.keys())
868
+ return np.transpose(arr, [2, 0, 1]), out_dict
869
+ elif self.model_arch == '1d-unet':
870
+ arr = np.array(hidden_state,
871
+ dtype=np.float32) # seqlen, dim
872
+ if self.eigen_transform is not None:
873
+ old_shape = arr.shape
874
+ arr = arr.reshape(1, -1) - self.eigen_transform['mean']
875
+ arr = arr @ self.eigen_transform['map']
876
+ arr = arr.reshape(old_shape)
877
+ if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
878
+ arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
879
+ arr = np.transpose(arr, [1, 0])
880
+ out_dict = {}
881
+ out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
882
+ # out_dict['mapping_func'] = self.mapping_func
883
+ # if self.local_classes is not None:
884
+ # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
885
+ # print(arr.shape)
886
+ return arr, out_dict
887
+ else:
888
+ arr = np.array(hidden_state,
889
+ dtype=np.float32)
890
+ if self.eigen_transform is not None:
891
+ old_shape = arr.shape
892
+ # arr = arr.reshape(1, -1) @ self.eigen_transform
893
+ arr = arr.reshape(1, -1) - self.eigen_transform['mean']
894
+ arr = arr @ self.eigen_transform['map']
895
+ arr = arr.reshape(old_shape)
896
+
897
+ if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
898
+ # print(arr.dtype)
899
+ # print(self.data_args.noise_level, 'using the noise level.')
900
+ arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
901
+ # print(arr.dtype)
902
+
903
+ out_dict = {}
904
+ out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
905
+ # out_dict['mapping_func'] = self.mapping_func
906
+ if self.data_args.experiment_mode == 'conditional_gen':
907
+ out_dict['src_ids'] = np.array(self.text_datasets['train'][idx]['src_ids'])
908
+ out_dict['src_mask'] = np.array(self.text_datasets['train'][idx]['src_mask'])
909
+ # if self.local_classes is not None:
910
+ # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
911
+ return arr, out_dict
912
+
913
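+ # Right-pad (or truncate) each example to max_length with pad_token_id; optionally also return a 0/1 attention mask.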
+ def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
914
+ result = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
915
+ mask_ = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
916
+ for i, example in enumerate(examples):
917
+ curr_len = min(len(example), max_length)
918
+ result[i][:curr_len] = example[:curr_len]
919
+ mask_[i][:curr_len] = [1] * curr_len
920
+ if return_mask:
921
+ return result, mask_
922
+ return result
923
+
924
+ def _torch_collate_batch(examples, pad_token_id, max_length):
925
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
926
+ import numpy as np
927
+ import torch
928
+
929
+ # Tensorize if necessary.
930
+ if isinstance(examples[0], (list, tuple, np.ndarray)):
931
+ examples = [torch.tensor(e, dtype=torch.long) for e in examples]
932
+
933
+ # length_of_first = examples[0].size(0)
934
+ # Check if padding is necessary.
935
+ # are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
936
+ # if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
937
+ # return torch.stack(examples, dim=0)
938
+ # Creating the full tensor and filling it with our data.
939
+ # max_length = max(x.size(0) for x in examples)
940
+ # if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
941
+ # max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
942
+ result = examples[0].new_full([len(examples), max_length], pad_token_id)
943
+ for i, example in enumerate(examples):
944
+ result[i, : example.shape[0]] = example
948
+ return result
src/improved_diffusion/train_util.py ADDED
@@ -0,0 +1,445 @@
1
+ import os
2
+ import copy
3
+ import functools
4
+ import blobfile as bf
5
+ import torch
6
+ import torch.distributed as dist
7
+ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
8
+ from torch.optim import AdamW
9
+
10
+ from . import dist_util, logger
11
+ from .fp16_util import (
12
+ make_master_params,
13
+ master_params_to_model_params,
14
+ model_grads_to_master_grads,
15
+ unflatten_master_params,
16
+ zero_grad,
17
+ )
18
+ from .nn import update_ema
19
+ from .resample import LossAwareSampler, UniformSampler
20
+ import wandb
21
+ from tqdm import tqdm
22
+
23
+ INITIAL_LOG_LOSS_SCALE = 20.0
24
+
25
+
26
+ class TrainLoop:
27
+ def __init__(
28
+ self,
29
+ *,
30
+ model,
31
+ diffusion,
32
+ data,
33
+ batch_size,
34
+ microbatch,
35
+ lr,
36
+ ema_rate,
37
+ log_interval,
38
+ save_interval,
39
+ resume_checkpoint,
40
+ use_fp16=False,
41
+ fp16_scale_growth=1e-3,
42
+ schedule_sampler=None,
43
+ weight_decay=0.0,
44
+ lr_anneal_steps=0,
45
+ checkpoint_path="",
46
+ gradient_clipping=-1.0,
47
+ eval_data=None,
48
+ eval_interval=-1,
49
+ ):
50
+ print('Initiating train loop')
51
+ rank = dist.get_rank()
52
+ world_size = dist.get_world_size()
53
+ self.rank = rank
54
+ self.world_size = world_size
55
+ self.diffusion = diffusion
56
+ self.data = data
57
+ self.eval_data = eval_data
58
+ self.batch_size = batch_size
59
+ self.microbatch = microbatch if microbatch > 0 else batch_size
60
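+ # Scale the base learning rate by the number of workers (linear scaling rule for the larger effective batch).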
+ self.lr = lr * world_size
61
+ self.ema_rate = (
62
+ [ema_rate]
63
+ if isinstance(ema_rate, float)
64
+ else [float(x) for x in ema_rate.split(",")]
65
+ )
66
+ self.log_interval = log_interval
67
+ self.eval_interval = eval_interval
68
+ self.save_interval = save_interval
69
+ self.resume_checkpoint = resume_checkpoint
70
+ self.use_fp16 = use_fp16
71
+ self.fp16_scale_growth = fp16_scale_growth
72
+ self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
73
+ self.weight_decay = weight_decay
74
+ self.lr_anneal_steps = lr_anneal_steps
75
+ self.gradient_clipping = gradient_clipping
76
+
77
+ self.step = 0
78
+ self.resume_step = 0
79
+ self.global_batch = self.batch_size * dist.get_world_size()
80
+
81
+ self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
82
+ self.sync_cuda = torch.cuda.is_available()
83
+ self.checkpoint_path = checkpoint_path
84
+
85
+ self.model = model.to(rank)
86
+
87
+ if torch.cuda.is_available(): # DEBUG **
88
+ self.use_ddp = True
89
+ self.ddp_model = self.model
90
+ # self.ddp_model = DDP(
91
+ # self.model,
92
+ # device_ids=[self.rank],
93
+ # find_unused_parameters=False,
94
+ # )
95
+ else:
96
+ self.ddp_model = model.to("cpu")
97
+
98
+ self.model_params = list(self.ddp_model.parameters())
99
+ self.master_params = self.model_params
100
+ self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
101
+ if self.resume_step:
102
+ # self._load_optimizer_state()
103
+ # # Model was resumed, either due to a restart or a checkpoint
104
+ # # being specified at the command line.
105
+ # self.ema_params = [
106
+ # self._load_ema_parameters(rate) for rate in self.ema_rate
107
+ # ]
108
+ pass
109
+ else:
110
+ self.ema_params = [
111
+ copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
112
+ ]
113
+ print('Finished initiating train loop')
114
+
115
+ def _load_and_sync_parameters(self):
116
+ resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
117
+
118
+ if resume_checkpoint:
119
+ self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
120
+ if dist.get_rank() == 0:
121
+ # logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
122
+ print(f"loading model from checkpoint: {resume_checkpoint}...")
123
+ self.model.load_state_dict(
124
+ dist_util.load_state_dict(
125
+ resume_checkpoint, map_location=dist_util.dev()
126
+ )
127
+ )
128
+
129
+ dist_util.sync_params(self.model.parameters())
130
+
131
+ def _load_ema_parameters(self, rate):
132
+ ema_params = copy.deepcopy(self.master_params)
133
+
134
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
135
+ ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
136
+ if ema_checkpoint:
137
+ if dist.get_rank() == 0:
138
+ logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
139
+ state_dict = dist_util.load_state_dict(
140
+ ema_checkpoint, map_location=dist_util.dev()
141
+ )
142
+ ema_params = self._state_dict_to_master_params(state_dict)
143
+
144
+ dist_util.sync_params(ema_params)
145
+ return ema_params
146
+
147
+ def _load_optimizer_state(self):
148
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
149
+ opt_checkpoint = bf.join(
150
+ bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
151
+ )
152
+ if bf.exists(opt_checkpoint):
153
+ logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
154
+ state_dict = dist_util.load_state_dict(
155
+ opt_checkpoint, map_location=dist_util.dev()
156
+ )
157
+ self.opt.load_state_dict(state_dict)
158
+
159
+ def _setup_fp16(self):
160
+ self.master_params = make_master_params(self.model_params)
161
+ self.model.convert_to_fp16()
162
+
163
+ def run_loop(self):
164
+ pbar = tqdm(total=self.lr_anneal_steps // self.world_size)
165
+ print('Start running train loop')
166
+ while (
167
+ not self.lr_anneal_steps
168
+ or self.step + self.resume_step < self.lr_anneal_steps // self.world_size
169
+ ):
170
+ pbar.set_description(f"Step: {self.step + self.resume_step}")
171
+ batch = next(self.data)
172
+ # if self.step<3:
173
+ # print("RANK:",self.rank,"STEP:",self.step,"BATCH:",batch)
174
+ self.run_step(batch, cond=None)
175
+ if self.step % self.log_interval == 0:
176
+ # dist.barrier()
177
+ pass
178
+ # print('loggggg')
179
+ # logger.dumpkvs()
180
+ if self.eval_data is not None and self.step % self.eval_interval == 0:
181
+ # batch_eval, cond_eval = next(self.eval_data)
182
+ # self.forward_only(batch, cond)
183
+ print("eval on validation set")
184
+ pass # logger.dumpkvs()
185
+ if self.step % self.save_interval == 0 and self.step != 0:
186
+ self.save()
187
+ # Run for a finite amount of time in integration tests.
188
+ if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
189
+ return
190
+ self.step += 1
191
+ pbar.update(1)
192
+ # Save the last checkpoint if it wasn't already saved.
193
+ if (self.step - 1) % self.save_interval != 0:
194
+ self.save()
195
+
196
+ def run_step(self, batch, cond):
197
+ self.forward_backward(batch, cond)
198
+ if self.use_fp16:
199
+ self.optimize_fp16()
200
+ else:
201
+ self.optimize_normal()
202
+ self.log_step()
203
+
204
+ def forward_only(self, batch, cond):
205
+ with torch.no_grad():
206
+ zero_grad(self.model_params)
207
+ for i in range(0, batch.shape[0], self.microbatch):
208
+ micro = batch[i : i + self.microbatch].to(dist_util.dev())
209
+ micro_cond = {
210
+ k: v[i : i + self.microbatch].to(dist_util.dev())
211
+ for k, v in cond.items()
212
+ }
213
+ last_batch = (i + self.microbatch) >= batch.shape[0]
214
+ t, weights = self.schedule_sampler.sample(
215
+ micro.shape[0], dist_util.dev()
216
+ )
217
+ # print(micro_cond.keys())
218
+ compute_losses = functools.partial(
219
+ self.diffusion.training_losses,
220
+ self.ddp_model,
221
+ micro,
222
+ t,
223
+ micro_cond,
224
+ )
225
+
226
+ if last_batch or not self.use_ddp:
227
+ losses = compute_losses()
228
+ else:
229
+ with self.ddp_model.no_sync():
230
+ losses = compute_losses()
231
+
232
+ log_loss_dict(
233
+ self.diffusion,
234
+ t,
235
+ {f"eval_{k}": v * weights for k, v in losses.items()},
236
+ )
237
+
238
+ def forward_backward(self, batch, cond):
239
+ # zero_grad(self.model_params)
240
+ self.opt.zero_grad()
241
+ for i in range(0, batch[0].shape[0], self.microbatch):
242
+ # micro = batch[i : i + self.microbatch].to(self.rank)
243
+ # last_batch = (i + self.microbatch) >= batch.shape[0]
244
+ # t, weights = self.schedule_sampler.sample(micro.shape[0], self.rank)
245
+
246
+ micro = (
247
+ batch[0].to(self.rank), # selfies_ids
248
+ batch[1].to(self.rank), # caption_state
249
+ batch[2].to(self.rank), # caption_mask
250
+ batch[3].to(self.rank), # corrupted_selfies_ids
251
+ )
252
+ last_batch = True
253
+ t, weights = self.schedule_sampler.sample(micro[0].shape[0], self.rank)
254
+
255
+ compute_losses = functools.partial(
256
+ self.diffusion.training_losses,
257
+ self.ddp_model,
258
+ micro,
259
+ t,
260
+ None,
261
+ )
262
+
263
+ if last_batch or not self.use_ddp:
264
+ losses = compute_losses()
265
+ else:
266
+ with self.ddp_model.no_sync():
267
+ losses = compute_losses()
268
+
269
+ if isinstance(self.schedule_sampler, LossAwareSampler):
270
+ self.schedule_sampler.update_with_local_losses(
271
+ t, losses["loss"].detach()
272
+ )
273
+
274
+ loss = (losses["loss"] * weights).mean()
275
+ # print('----DEBUG-----',self.step,self.log_interval)
276
+ if self.step % self.log_interval == 0 and self.rank == 0:
277
+ print("rank0: ", self.step, loss.item())
278
+ wandb.log({"loss": loss.item()})
279
+ # log_loss_dict(
280
+ # self.diffusion, t, {k: v * weights for k, v in losses.items()}
281
+ # )
282
+ if self.use_fp16:
283
+ # loss_scale = 2 ** self.lg_loss_scale
284
+ # (loss * loss_scale).backward()
285
+ pass
286
+ else:
287
+ loss.backward()
288
+
289
+ def optimize_fp16(self):
290
+ if any(not torch.isfinite(p.grad).all() for p in self.model_params):
291
+ self.lg_loss_scale -= 1
292
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
293
+ return
294
+
295
+ model_grads_to_master_grads(self.model_params, self.master_params)
296
+ self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale))
297
+ self._log_grad_norm()
298
+ self._anneal_lr()
299
+ self.opt.step()
300
+ for rate, params in zip(self.ema_rate, self.ema_params):
301
+ update_ema(params, self.master_params, rate=rate)
302
+ master_params_to_model_params(self.model_params, self.master_params)
303
+ self.lg_loss_scale += self.fp16_scale_growth
304
+
305
+ def grad_clip(self):
306
+ # print('doing gradient clipping')
307
+ max_grad_norm = self.gradient_clipping # 3.0
308
+ if hasattr(self.opt, "clip_grad_norm"):
309
+ # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
310
+ self.opt.clip_grad_norm(max_grad_norm)
311
+ # else:
312
+ # assert False
313
+ # elif hasattr(self.model, "clip_grad_norm_"):
314
+ # # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
315
+ # self.model.clip_grad_norm_(args.max_grad_norm)
316
+ else:
317
+ # Revert to normal clipping otherwise, handling Apex or full precision
318
+ torch.nn.utils.clip_grad_norm_(
319
+ self.model.parameters(), # amp.master_params(self.opt) if self.use_apex else
320
+ max_grad_norm,
321
+ )
322
+
323
+ def optimize_normal(self):
324
+ if self.gradient_clipping > 0:
325
+ self.grad_clip()
326
+ # self._log_grad_norm()
327
+ self._anneal_lr()
328
+ self.opt.step()
329
+ for rate, params in zip(self.ema_rate, self.ema_params):
330
+ update_ema(params, self.master_params, rate=rate)
331
+
332
+ def _log_grad_norm(self):
333
+ sqsum = 0.0
334
+ for p in self.master_params:
335
+ sqsum += (p.grad**2).sum().item()
336
+ # logger.logkv_mean("grad_norm", np.sqrt(sqsum))
337
+
338
+ def _anneal_lr(self):
339
+ if not self.lr_anneal_steps:
340
+ return
341
+ frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
342
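+ # Linear schedule: decay the learning rate from its initial value to zero over lr_anneal_steps.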
+ lr = self.lr * (1 - frac_done)
343
+ for param_group in self.opt.param_groups:
344
+ param_group["lr"] = lr
345
+
346
+ def log_step(self):
347
+ logger.logkv("step", self.step + self.resume_step)
348
+ # logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
349
+ if self.use_fp16:
350
+ logger.logkv("lg_loss_scale", self.lg_loss_scale)
351
+
352
+ def save(self):
353
+ def save_checkpoint(rate, params):
354
+ state_dict = self._master_params_to_state_dict(params)
355
+ if dist.get_rank() == 0:
356
+ # logger.log(f"saving model {rate}...")
357
+ print(f"saving model {rate}...")
358
+ if not rate:
359
+ filename = f"PLAIN_model{((self.step+self.resume_step)*self.world_size):06d}.pt"
360
+ else:
361
+ filename = f"PLAIN_ema_{rate}_{((self.step+self.resume_step)*self.world_size):06d}.pt"
362
+ # print('writing to', bf.join(get_blob_logdir(), filename))
363
+ # print('writing to', bf.join(self.checkpoint_path, filename))
364
+ # with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
365
+ # torch.save(state_dict, f)
366
+ with bf.BlobFile(
367
+ bf.join(self.checkpoint_path, filename), "wb"
368
+ ) as f: # DEBUG **
369
+ torch.save(state_dict, f)
370
+
371
+ save_checkpoint(0, self.master_params)
372
+ for rate, params in zip(self.ema_rate, self.ema_params):
373
+ save_checkpoint(rate, params)
374
+
375
+ # if dist.get_rank() == 0: # DEBUG **
376
+ # with bf.BlobFile(
377
+ # bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
378
+ # "wb",
379
+ # ) as f:
380
+ # torch.save(self.opt.state_dict(), f)
381
+
382
+ dist.barrier()
383
+
384
+ def _master_params_to_state_dict(self, master_params):
385
+ if self.use_fp16:
386
+ master_params = unflatten_master_params(
387
+ list(self.model.parameters()), master_params # DEBUG **
388
+ )
389
+ state_dict = self.model.state_dict()
390
+ for i, (name, _value) in enumerate(self.model.named_parameters()):
391
+ assert name in state_dict
392
+ state_dict[name] = master_params[i]
393
+ return state_dict
394
+
395
+ def _state_dict_to_master_params(self, state_dict):
396
+ params = [state_dict[name] for name, _ in self.model.named_parameters()]
397
+ if self.use_fp16:
398
+ return make_master_params(params)
399
+ else:
400
+ return params
401
+
402
+
403
+ def parse_resume_step_from_filename(filename):
404
+ """
405
+ Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
406
+ checkpoint's number of steps.
407
+ """
408
+ split = filename.split("model")
409
+ if len(split) < 2:
410
+ return 0
411
+ split1 = split[-1].split(".")[0]
412
+ try:
413
+ return int(split1)
414
+ except ValueError:
415
+ return 0
416
+
417
+
418
+ def get_blob_logdir():
419
+ return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())
420
+
421
+
422
+ def find_resume_checkpoint():
423
+ # On your infrastructure, you may want to override this to automatically
424
+ # discover the latest checkpoint on your blob storage, etc.
425
+ return None
426
+
427
+
428
+ def find_ema_checkpoint(main_checkpoint, step, rate):
429
+ if main_checkpoint is None:
430
+ return None
431
+ filename = f"ema_{rate}_{(step):06d}.pt"
432
+ path = bf.join(bf.dirname(main_checkpoint), filename)
433
+ if bf.exists(path):
434
+ return path
435
+ return None
436
+
437
+
438
+ def log_loss_dict(diffusion, ts, losses):
439
+ return  # logging intentionally disabled; remove this early return to re-enable the per-quartile loss logging below
440
+ for key, values in losses.items():
441
+ logger.logkv_mean(key, values.mean().item())
442
+ # Log the quantiles (four quartiles, in particular).
443
+ for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
444
+ quartile = int(4 * sub_t / diffusion.num_timesteps)
445
+ logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
src/improved_diffusion/transformer_model.py ADDED
@@ -0,0 +1,118 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, T5EncoderModel
6
+
7
+ from .nn import SiLU, linear, timestep_embedding
8
+
9
+
10
+ class TransformerNetModel(nn.Module):
11
+ def __init__(
12
+ self,
13
+ in_channels=32,
14
+ model_channels=128,
15
+ dropout=0.1,
16
+ config_name="QizhiPei/biot5-base-text2mol",
17
+ vocab_size=None, # 821
18
+ hidden_size=768,
19
+ num_attention_heads=12,
20
+ num_hidden_layers=12,
21
+ ):
22
+ super().__init__()
23
+
24
+ config = AutoConfig.from_pretrained(config_name)
25
+ config.is_decoder = True
26
+ config.add_cross_attention = True
27
+ config.hidden_dropout_prob = 0.1
28
+ config.num_attention_heads = num_attention_heads
29
+ config.num_hidden_layers = num_hidden_layers
30
+ config.max_position_embeddings = 512
31
+ config.layer_norm_eps = 1e-12
32
+ config.vocab_size = vocab_size
33
+ config.d_model = hidden_size
34
+
35
+ self.hidden_size = hidden_size
36
+ self.in_channels = in_channels
37
+ self.model_channels = model_channels
38
+ self.dropout = dropout
39
+ self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
40
+ self.lm_head = nn.Linear(self.in_channels, vocab_size)
41
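+ # Tie the output projection to the input embedding so logits are scored against the token embedding matrix.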
+ self.lm_head.weight = self.word_embedding.weight
42
+
43
+ self.caption_down_proj = nn.Sequential(
44
+ linear(768, self.hidden_size),
45
+ SiLU(),
46
+ linear(self.hidden_size, self.hidden_size),
47
+ )
48
+
49
+ time_embed_dim = model_channels * 4 # 512
50
+ self.time_embed = nn.Sequential(
51
+ linear(self.model_channels, time_embed_dim),
52
+ SiLU(),
53
+ linear(time_embed_dim, self.hidden_size),
54
+ )
55
+
56
+ self.input_up_proj = nn.Sequential(
57
+ nn.Linear(self.in_channels, self.hidden_size),
58
+ nn.Tanh(),
59
+ nn.Linear(self.hidden_size, self.hidden_size),
60
+ )
61
+
62
+ self.input_transformers = T5EncoderModel(config)
63
+ # self.input_transformers.eval()
64
+ # for param in self.input_transformers.parameters():
65
+ # param.requires_grad = False
66
+
67
+ self.register_buffer(
68
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
69
+ )
70
+ self.position_embeddings = nn.Embedding(
71
+ config.max_position_embeddings, self.hidden_size
72
+ )
73
+
74
+ self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
75
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
76
+ self.output_down_proj = nn.Sequential(
77
+ nn.Linear(self.hidden_size, self.hidden_size),
78
+ nn.Tanh(),
79
+ nn.Linear(self.hidden_size, self.in_channels),
80
+ )
81
+
82
+ def get_embeds(self, input_ids):
83
+ return self.word_embedding(input_ids)
84
+
85
+ def get_embeds_with_deep(self, input_ids):
86
+ atom, deep = input_ids
87
+ atom = self.word_embedding(atom)
88
+ deep = self.deep_embedding(deep)
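+ # NOTE: self.deep_embedding is not defined in this class; this helper assumes a subclass adds it.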
89
+
90
+ return torch.concat([atom, deep], dim=-1)
91
+
92
+ def get_logits(self, hidden_repr):
93
+ return self.lm_head(hidden_repr)
94
+
95
+ def forward(self, x, timesteps, caption_state, caption_mask, y=None):
96
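+ # Embed the diffusion timestep with sinusoidal features and project it to the model's hidden size.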
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
97
+ emb_x = self.input_up_proj(x)
98
+ seq_length = x.size(1)
99
+ position_ids = self.position_ids[:, :seq_length]
100
+ emb_inputs = (
101
+ self.position_embeddings(position_ids)
102
+ + emb_x
103
+ + emb.unsqueeze(1).expand(-1, seq_length, -1)
104
+ )
105
+ emb_inputs = self.dropout(self.LayerNorm(emb_inputs))
106
+
107
+ caption_state = self.dropout(
108
+ self.LayerNorm(self.caption_down_proj(caption_state))
109
+ )
110
+
111
+ input_trans_hidden_states = self.input_transformers.encoder(
112
+ inputs_embeds=emb_inputs,
113
+ encoder_hidden_states=caption_state,
114
+ encoder_attention_mask=caption_mask,
115
+ ).last_hidden_state
116
+ h = self.output_down_proj(input_trans_hidden_states)
117
+ h = h.type(x.dtype)
118
+ return h
src/improved_diffusion/transformer_utils.py ADDED
@@ -0,0 +1,450 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+
19
+ import math
20
+ import os
21
+ import warnings
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from packaging import version
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+ from transformers.activations import ACT2FN
31
+ from transformers.file_utils import (
32
+ ModelOutput,
33
+ add_code_sample_docstrings,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ replace_return_docstrings,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPastAndCrossAttentions,
40
+ BaseModelOutputWithPoolingAndCrossAttentions,
41
+ CausalLMOutputWithCrossAttentions,
42
+ MaskedLMOutput,
43
+ MultipleChoiceModelOutput,
44
+ NextSentencePredictorOutput,
45
+ QuestionAnsweringModelOutput,
46
+ SequenceClassifierOutput,
47
+ TokenClassifierOutput,
48
+ )
49
+ from transformers.modeling_utils import (
50
+ PreTrainedModel,
51
+ apply_chunking_to_forward,
52
+ find_pruneable_heads_and_indices,
53
+ prune_linear_layer,
54
+ )
55
+ from transformers.utils import logging
56
+ from transformers.models.bert.configuration_bert import BertConfig
57
+
58
+
59
+ logger = logging.get_logger(__name__)
60
+
61
+ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
62
+ _CONFIG_FOR_DOC = "BertConfig"
63
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
64
+
65
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
66
+ "bert-base-uncased",
67
+ "bert-large-uncased",
68
+ "bert-base-cased",
69
+ "bert-large-cased",
70
+ "bert-base-multilingual-uncased",
71
+ "bert-base-multilingual-cased",
72
+ "bert-base-chinese",
73
+ "bert-base-german-cased",
74
+ "bert-large-uncased-whole-word-masking",
75
+ "bert-large-cased-whole-word-masking",
76
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
77
+ "bert-large-cased-whole-word-masking-finetuned-squad",
78
+ "bert-base-cased-finetuned-mrpc",
79
+ "bert-base-german-dbmdz-cased",
80
+ "bert-base-german-dbmdz-uncased",
81
+ "cl-tohoku/bert-base-japanese",
82
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
83
+ "cl-tohoku/bert-base-japanese-char",
84
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
85
+ "TurkuNLP/bert-base-finnish-cased-v1",
86
+ "TurkuNLP/bert-base-finnish-uncased-v1",
87
+ "wietsedv/bert-base-dutch-cased",
88
+ # See all BERT models at https://huggingface.co/models?filter=bert
89
+ ]
90
+
91
+
92
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
93
+ """Load tf checkpoints in a pytorch model."""
94
+ try:
95
+ import re
96
+
97
+ import numpy as np
98
+ import tensorflow as tf
99
+ except ImportError:
100
+ logger.error(
101
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
102
+ "https://www.tensorflow.org/install/ for installation instructions."
103
+ )
104
+ raise
105
+ tf_path = os.path.abspath(tf_checkpoint_path)
106
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
107
+ # Load weights from TF model
108
+ init_vars = tf.train.list_variables(tf_path)
109
+ names = []
110
+ arrays = []
111
+ for name, shape in init_vars:
112
+ logger.info(f"Loading TF weight {name} with shape {shape}")
113
+ array = tf.train.load_variable(tf_path, name)
114
+ names.append(name)
115
+ arrays.append(array)
116
+
117
+ for name, array in zip(names, arrays):
118
+ name = name.split("/")
119
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
120
+ # which are not required for using pretrained model
121
+ if any(
122
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
123
+ for n in name
124
+ ):
125
+ logger.info(f"Skipping {'/'.join(name)}")
126
+ continue
127
+ pointer = model
128
+ for m_name in name:
129
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
130
+ scope_names = re.split(r"_(\d+)", m_name)
131
+ else:
132
+ scope_names = [m_name]
133
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
134
+ pointer = getattr(pointer, "weight")
135
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
136
+ pointer = getattr(pointer, "bias")
137
+ elif scope_names[0] == "output_weights":
138
+ pointer = getattr(pointer, "weight")
139
+ elif scope_names[0] == "squad":
140
+ pointer = getattr(pointer, "classifier")
141
+ else:
142
+ try:
143
+ pointer = getattr(pointer, scope_names[0])
144
+ except AttributeError:
145
+ logger.info(f"Skipping {'/'.join(name)}")
146
+ continue
147
+ if len(scope_names) >= 2:
148
+ num = int(scope_names[1])
149
+ pointer = pointer[num]
150
+ if m_name[-11:] == "_embeddings":
151
+ pointer = getattr(pointer, "weight")
152
+ elif m_name == "kernel":
153
+ array = np.transpose(array)
154
+ try:
155
+ if pointer.shape != array.shape:
156
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
157
+ except AssertionError as e:
158
+ e.args += (pointer.shape, array.shape)
159
+ raise
160
+ logger.info(f"Initialize PyTorch weight {name}")
161
+ pointer.data = torch.from_numpy(array)
162
+ return model
163
+
164
+
165
+ class BertEmbeddings(nn.Module):
166
+ """Construct the embeddings from word, position and token_type embeddings."""
167
+
168
+ def __init__(self, config):
169
+ super().__init__()
170
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
171
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
172
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
173
+
174
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
175
+ # any TensorFlow checkpoint file
176
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
177
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
178
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
179
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
180
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
181
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
182
+ self.register_buffer(
183
+ "token_type_ids",
184
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
185
+ persistent=False,
186
+ )
187
+
188
+ def forward(
189
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
190
+ ):
191
+ if input_ids is not None:
192
+ input_shape = input_ids.size()
193
+ else:
194
+ input_shape = inputs_embeds.size()[:-1]
195
+
196
+ seq_length = input_shape[1]
197
+
198
+ if position_ids is None:
199
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
200
+
201
+ # When token_type_ids is not provided, default to the all-zeros buffer registered in the
202
+ # constructor; this lets users trace the model without passing token_type_ids and
203
+ # resolves issue #5664
204
+ if token_type_ids is None:
205
+ if hasattr(self, "token_type_ids"):
206
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
207
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
208
+ token_type_ids = buffered_token_type_ids_expanded
209
+ else:
210
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
211
+
212
+ if inputs_embeds is None:
213
+ inputs_embeds = self.word_embeddings(input_ids)
214
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
215
+
216
+ embeddings = inputs_embeds + token_type_embeddings
217
+ if self.position_embedding_type == "absolute":
218
+ position_embeddings = self.position_embeddings(position_ids)
219
+ embeddings += position_embeddings
220
+ embeddings = self.LayerNorm(embeddings)
221
+ embeddings = self.dropout(embeddings)
222
+ return embeddings
223
+
224
+
225
+ class BertSelfAttention(nn.Module):
226
+ def __init__(self, config, hidden_size, num_attention_heads, attention_head_size, position_embedding_type=None):
227
+ super().__init__()
228
+ # hidden_size, num_attention_heads, attention_probs_dropout_prob
229
+ # if hidden_size % num_attention_heads != 0 and not hasattr(config, "embedding_size"):
230
+ # raise ValueError(
231
+ # f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
232
+ # f"heads ({num_attention_heads})"
233
+ # )
234
+
235
+ self.num_attention_heads = num_attention_heads
236
+ self.attention_head_size = attention_head_size
237
+ # self.attention_head_size = int(hidden_size / num_attention_heads)
238
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
239
+
240
+ self.query = nn.Linear(hidden_size, self.all_head_size)
241
+ self.key = nn.Linear(hidden_size, self.all_head_size)
242
+ self.value = nn.Linear(hidden_size, self.all_head_size)
243
+
244
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
245
+ self.position_embedding_type = position_embedding_type or getattr(
246
+ config, "position_embedding_type", "absolute"
247
+ )
248
+ # print(self.position_embedding_type, config.max_position_embeddings)
249
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
250
+ self.max_position_embeddings = config.max_position_embeddings
251
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
252
+
253
+ self.is_decoder = config.is_decoder
254
+
255
+ def transpose_for_scores(self, x):
256
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
257
+ x = x.view(*new_x_shape)
258
+ # print(x.shape)
259
+ return x.permute(0, 2, 1, 3)
260
+
261
+ def forward(
262
+ self,
263
+ hidden_states,
264
+ attention_mask=None,
265
+ head_mask=None,
266
+ encoder_hidden_states=None,
267
+ encoder_attention_mask=None,
268
+ past_key_value=None,
269
+ output_attentions=False,
270
+ ):
271
+ mixed_query_layer = self.query(hidden_states)
272
+
273
+ # If this is instantiated as a cross-attention module, the keys
274
+ # and values come from an encoder; the attention mask needs to be
275
+ # such that the encoder's padding tokens are not attended to.
276
+ is_cross_attention = encoder_hidden_states is not None
277
+
278
+ if is_cross_attention and past_key_value is not None:
279
+ # reuse k,v, cross_attentions
280
+ key_layer = past_key_value[0]
281
+ value_layer = past_key_value[1]
282
+ attention_mask = encoder_attention_mask
283
+ elif is_cross_attention:
284
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
285
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
286
+ attention_mask = encoder_attention_mask
287
+ elif past_key_value is not None:
288
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
289
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
290
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
291
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
292
+ else:
293
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
294
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
295
+
296
+ query_layer = self.transpose_for_scores(mixed_query_layer)
297
+
298
+ if self.is_decoder:
299
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
300
+ # Further calls to cross_attention layer can then reuse all cross-attention
301
+ # key/value_states (first "if" case)
302
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
303
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
304
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
305
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
306
+ past_key_value = (key_layer, value_layer)
307
+
308
+ # Take the dot product between "query" and "key" to get the raw attention scores.
309
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
310
+
311
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
312
+ seq_length = hidden_states.size()[1]
313
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
314
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
315
+ distance = position_ids_l - position_ids_r
316
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
317
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
318
+
319
+ if self.position_embedding_type == "relative_key":
320
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
321
+ attention_scores = attention_scores + relative_position_scores
322
+ elif self.position_embedding_type == "relative_key_query":
323
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
324
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
325
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
326
+
327
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
328
+ if attention_mask is not None:
329
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
330
+ attention_scores = attention_scores + attention_mask
331
+
332
+ # Normalize the attention scores to probabilities.
333
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
334
+
335
+ # This is actually dropping out entire tokens to attend to, which might
336
+ # seem a bit unusual, but is taken from the original Transformer paper.
337
+ attention_probs = self.dropout(attention_probs)
338
+
339
+ # Mask heads if we want to
340
+ if head_mask is not None:
341
+ attention_probs = attention_probs * head_mask
342
+
343
+ context_layer = torch.matmul(attention_probs, value_layer)
344
+
345
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
346
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
347
+ context_layer = context_layer.view(*new_context_layer_shape)
348
+
349
+ # outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
350
+ #
351
+ # if self.is_decoder:
352
+ # outputs = outputs + (past_key_value,)
353
+ return context_layer
354
+
355
+ class BertOutput(nn.Module):
356
+ def __init__(self, config):
357
+ super().__init__()
358
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
359
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
360
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
361
+
362
+ def forward(self, hidden_states, input_tensor):
363
+ hidden_states = self.dense(hidden_states)
364
+ hidden_states = self.dropout(hidden_states)
365
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
366
+ return hidden_states
367
+
368
+ class BertSelfOutput(nn.Module):
369
+ def __init__(self, config, hidden_size, input_hidden_size):
370
+ super().__init__()
371
+ self.dense = nn.Linear(hidden_size, hidden_size)
372
+
373
+ if input_hidden_size != hidden_size:
374
+ self.rescale = True
375
+ self.dense2 = nn.Linear(input_hidden_size, hidden_size)
376
+ else:
377
+ self.rescale = False
378
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
379
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
380
+
381
+ def forward(self, hidden_states, input_tensor):
382
+ hidden_states = self.dense(hidden_states)
383
+ if self.rescale:
384
+ input_tensor2 = self.dense2(input_tensor)
385
+ else:
386
+ input_tensor2 = input_tensor
387
+ hidden_states = self.dropout(hidden_states)
388
+ hidden_states = self.LayerNorm(hidden_states + input_tensor2)
389
+ return hidden_states
390
+
391
+ def trans_nd(config, hidden_size, num_attention_heads, attention_head_size):
392
+ return BertSelfAttention(config, hidden_size, num_attention_heads, attention_head_size,
393
+ position_embedding_type=None)
394
+
395
+ def layer_norm(hidden_size):
396
+ # print(f'layer norm, {hidden_size}')
397
+ return nn.LayerNorm(hidden_size)
398
+
399
+ class BertAttention(nn.Module):
400
+ def __init__(self, config, hidden_size, num_attention_heads, attention_head_size,
401
+ position_embedding_type=None):
402
+ super().__init__()
403
+ self.self = BertSelfAttention(config, hidden_size, num_attention_heads, attention_head_size,
404
+ position_embedding_type=position_embedding_type)
405
+ self.output = BertSelfOutput(config, num_attention_heads * attention_head_size, hidden_size)
406
+ self.pruned_heads = set()
407
+
408
+ def prune_heads(self, heads):
409
+ if len(heads) == 0:
410
+ return
411
+ heads, index = find_pruneable_heads_and_indices(
412
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
413
+ )
414
+
415
+ # Prune linear layers
416
+ self.self.query = prune_linear_layer(self.self.query, index)
417
+ self.self.key = prune_linear_layer(self.self.key, index)
418
+ self.self.value = prune_linear_layer(self.self.value, index)
419
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
420
+
421
+ # Update hyper params and store pruned heads
422
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
423
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
424
+ self.pruned_heads = self.pruned_heads.union(heads)
425
+
426
+ def forward(
427
+ self,
428
+ hidden_states,
429
+ attention_mask=None,
430
+ head_mask=None,
431
+ encoder_hidden_states=None,
432
+ encoder_attention_mask=None,
433
+ past_key_value=None,
434
+ output_attentions=False,
435
+ ):
436
+ self_outputs = self.self(
437
+ hidden_states,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ past_key_value,
443
+ output_attentions,
444
+ )
445
+
446
+ attention_output = self.output(self_outputs, hidden_states)
447
+ # print(self_outputs.shape, attention_output.shape, 'output of BertAttention')
448
+ # outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
449
+ return attention_output
450
+
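Note on the attention code above: this fork of BertSelfAttention takes explicit hidden_size / num_attention_heads / attention_head_size arguments (so all_head_size need not equal hidden_size, which is why BertSelfOutput grows the optional dense2 rescaling projection) and returns only the context tensor rather than the usual HF output tuple. The shape flow through transpose_for_scores and the scaled dot-product is easy to lose track of; a minimal sketch with made-up sizes (not code from this commit):

    import math
    import torch

    batch, seq, heads, head_size = 2, 8, 4, 16
    hidden = heads * head_size
    proj = torch.nn.Linear(hidden, hidden)  # one shared projection is enough for a shape demo
    x = torch.randn(batch, seq, hidden)

    def transpose_for_scores(t):
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        return t.view(batch, seq, heads, head_size).permute(0, 2, 1, 3)

    q, k, v = (transpose_for_scores(proj(x)) for _ in range(3))
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)  # (batch, heads, seq, seq)
    context = torch.matmul(scores.softmax(dim=-1), v)                     # (batch, heads, seq, head_size)
    context = context.permute(0, 2, 1, 3).reshape(batch, seq, hidden)
    print(context.shape)  # torch.Size([2, 8, 64])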
src/scripts/__init__.py ADDED
File without changes
src/scripts/batch_decode.py ADDED
@@ -0,0 +1,149 @@
1
+ import os, sys, glob
2
+ # full_lst = glob.glob('diff_models_synth128*')
3
+ # full_lst = glob.glob('diff_models_synth32*')
4
+ # full_lst = glob.glob('diff_models_synth32_3_rand16*')
5
+ # full_lst = glob.glob('diff_models_synth_rand_16_trans_lr_1e-5_long_Lsimple')
6
+ full_lst = glob.glob(sys.argv[1])
7
+ top_p = -1.0 if len(sys.argv) < 3 else float(sys.argv[2])
8
+ print(f'top_p = {top_p}')
9
+ pattern_ = 'model' if len(sys.argv) < 4 else sys.argv[3]
10
+ print(f'pattern_ = {pattern_}')
11
+ # print(full_lst)
12
+
13
+ output_lst = []
14
+ for lst in full_lst:
15
+ print(lst)
16
+ try:
17
+ tgt = sorted(glob.glob(f"{lst}/{pattern_}*pt"))[-1]
18
+ lst = os.path.split(lst)[1]
19
+ print(lst)
20
+ num = 1
21
+ except:
22
+ continue
23
+ model_arch_ = lst.split('_')[5-num]
24
+ model_arch = 'conv-unet' if 'conv-unet' in lst else 'transformer'
25
+ mode = 'image' if 'conv' in model_arch else 'text'  # or '1d-unet' in model_arch_
26
+ print(mode, model_arch_)
27
+ dim_ = lst.split('_')[4-num]
28
+
29
+ # diffusion_steps= 4000
30
+ # noise_schedule = 'cosine'
31
+ # dim = dim_.split('rand')[1]
32
+
33
+ if 'synth' in lst:
34
+ modality = 'synth'
35
+ elif 'pos' in lst:
36
+ modality = 'pos'
37
+ elif 'image' in lst:
38
+ modality = 'image'
39
+ elif 'roc' in lst:
40
+ modality = 'roc'
41
+ elif 'e2e-tgt' in lst:
42
+ modality = 'e2e-tgt'
43
+ elif 'simple-wiki' in lst:
44
+ modality = 'simple-wiki'
45
+ elif 'book' in lst:
46
+ modality = 'book'
47
+ elif 'yelp' in lst:
48
+ modality = 'yelp'
49
+ elif 'commonGen' in lst:
50
+ modality = 'commonGen'
51
+ elif 'e2e' in lst:
52
+ modality = 'e2e'
53
+
54
+
55
+ if 'synth32' in lst:
56
+ kk = 32
57
+ elif 'synth128' in lst:
58
+ kk = 128
59
+
60
+ try:
61
+ diffusion_steps = int(lst.split('_')[7-num])
62
+ print(diffusion_steps)
63
+ except:
64
+ diffusion_steps = 4000
65
+ try:
66
+ noise_schedule = lst.split('_')[8-num]
67
+ assert noise_schedule in ['cosine', 'linear']
68
+ print(noise_schedule)
69
+ except:
70
+ noise_schedule = 'cosine'
71
+ try:
72
+ dim = int(dim_.split('rand')[1])
73
+ except:
74
+ dim = lst.split('_')[4-num]
75
+ try:
76
+ print(len(lst.split('_')))
77
+ num_channels = int(lst.split('_')[-1].split('h')[1])
78
+ except:
79
+ num_channels = 128
80
+
81
+ print(tgt, model_arch, dim, num_channels)
82
+ # out_dir = 'diffusion_lm/improved_diffusion/out_gen_large_nucleus'
83
+ # num_samples = 512
84
+
85
+ # out_dir = 'diffusion_lm/improved_diffusion/out_gen_v2_nucleus'
86
+
87
+ out_dir = 'generation_outputs'
88
+ num_samples = 50
89
+
90
+ if modality == 'e2e':
91
+ num_samples = 547
92
+
93
+ COMMAND = f'python scripts/{mode}_sample.py ' \
94
+ f'--model_path {tgt} --batch_size 50 --num_samples {num_samples} --top_p {top_p} ' \
95
+ f'--out_dir {out_dir} '
96
+ print(COMMAND)
97
+ # os.system(COMMAND)
98
+
99
+ # shape_str = "x".join([str(x) for x in arr.shape])
100
+ model_base_name = os.path.basename(os.path.split(tgt)[0]) + f'.{os.path.split(tgt)[1]}'
101
+ if modality == 'e2e-tgt' or modality == 'e2e':
102
+ out_path2 = os.path.join(out_dir, f"{model_base_name}.samples_{top_p}.json")
103
+ else:
104
+ out_path2 = os.path.join(out_dir, f"{model_base_name}.samples_{top_p}.txt")
105
+ output_cands = glob.glob(out_path2)
106
+ print(out_path2, output_cands)
107
+ if len(output_cands) > 0:
108
+ out_path2 = glob.glob(out_path2)[0]
109
+ else:
110
+ os.system(COMMAND)
111
+ out_path2 = glob.glob(out_path2)[0]
112
+
113
+ output_lst.append(out_path2)
114
+
115
+ if modality == 'pos':
116
+ model_name_path = 'predictability/diff_models/pos_e=15_b=20_m=gpt2_wikitext-103-raw-v1_s=102'
117
+ elif modality == 'synth':
118
+ if kk == 128:
119
+ model_name_path = 'predictability/diff_models/synth_e=15_b=10_m=gpt2_wikitext-103-raw-v1_None'
120
+ else:
121
+ model_name_path = 'predictability/diff_models/synth_e=15_b=20_m=gpt2_wikitext-103-raw-v1_None'
122
+ elif modality == 'e2e-tgt':
123
+ model_name_path = "predictability/diff_models/e2e-tgt_e=15_b=20_m=gpt2_wikitext-103-raw-v1_101_None"
124
+ elif modality == 'roc':
125
+ model_name_path = "predictability/diff_models/roc_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad_v1"
126
+ elif modality == 'e2e':
127
+ COMMAND1 = f"python diffusion_lm/e2e_data/mbr.py {out_path2}"
128
+
129
+ os.system(COMMAND1)
130
+ COMMAND2 = f"python e2e-metrics/measure_scores.py " \
131
+ f"diffusion_lm/improved_diffusion/out_gen_v2_dropout2/1_valid_gold " \
132
+ f"{out_path2}.clean -p -t -H > {os.path.join(os.path.split(tgt)[0], 'e2e_valid_eval.txt')}"
133
+ print(COMMAND2)
134
+ os.system(COMMAND2)
135
+ continue
136
+ else:
137
+ print('No AR model trained for this modality yet; only inspect the generated output.')
138
+ continue
139
+ COMMAND = f"python scripts/ppl_under_ar.py " \
140
+ f"--model_path {tgt} " \
141
+ f"--modality {modality} --experiment random " \
142
+ f"--model_name_or_path {model_name_path} " \
143
+ f"--input_text {out_path2} --mode eval"
144
+
145
+ print(COMMAND)
146
+ print()
147
+ os.system(COMMAND)
148
+ print('output lists:')
149
+ print("\n".join(output_lst))
src/scripts/batch_nll.py ADDED
@@ -0,0 +1,29 @@
1
+ import os, sys, glob
2
+ full_lst = glob.glob(sys.argv[1])
3
+ pattern_ = 'model' if len(sys.argv) < 3 else sys.argv[2]
4
+ clamp = 'clamp' if len(sys.argv) < 4 else sys.argv[3]
5
+ print(f'pattern_ = {pattern_}')
6
+
7
+ for lst in full_lst:
8
+ print(lst)
9
+ try:
10
+ tgt = sorted(glob.glob(f"{lst}/{pattern_}*pt"))[-1]
11
+ lst = os.path.split(lst)[1]
12
+ print(lst)
13
+ num = 1
14
+ except:
15
+ continue
16
+
17
+ COMMAND = f'python scripts/nll.py --clip_denoised False ' \
18
+ f'--model_path {tgt} ' \
19
+ f'--out_dir diffusion_lm/improved_diffusion/scores_eval2_valid_None ' \
20
+ f'--num_samples 64 --split valid --clamp {clamp}'
21
+ print(COMMAND)
22
+ os.system(COMMAND)
23
+
24
+ COMMAND = f'python scripts/nll.py --clip_denoised False ' \
25
+ f'--model_path {tgt} ' \
26
+ f'--out_dir diffusion_lm/improved_diffusion/scores_eval2_valid_None ' \
27
+ f'--num_samples 64 --split train --clamp {clamp}'
28
+ print(COMMAND)
29
+ os.system(COMMAND)
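Both batch scripts read positional sys.argv slots, an idiom whose bounds checks are easy to get wrong. An argparse version of the same interface, shown only as a sketch of the safer idiom:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("glob_pattern", help="experiment folders, e.g. 'diff_models_*'")
    parser.add_argument("pattern", nargs="?", default="model", help="checkpoint file prefix")
    parser.add_argument("clamp", nargs="?", default="clamp")
    args = parser.parse_args()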
src/scripts/infill_util.py ADDED
@@ -0,0 +1,355 @@
1
+ import torch as th
2
+
3
+ def get_score(input_embs, label_ids, model_control, t=None):
4
+ label_ids2 = label_ids.clone()
5
+ label_ids2[:, :65] = -100
6
+ # print(label_ids2[:, 65:])
7
+ # print(final.shape, tgt_embs.shape)
8
+ # input_embs = th.cat([final, tgt_embs], dim=1)
9
+ model_out = model_control(input_embs=input_embs,
10
+ labels=label_ids2, t=t)
11
+ print(model_out.loss, 'final end')
12
+ loss_fn = th.nn.CrossEntropyLoss(reduction='none')
13
+ shifted_logits = model_out.logits[:, :-1].contiguous()
14
+ shifted_labels = label_ids2[:, 1:].contiguous()
15
+ loss = loss_fn(shifted_logits.view(-1, shifted_logits.size(-1)), shifted_labels.view(-1)).reshape(
16
+ shifted_labels.shape)
17
+ return loss.sum(dim=-1).tolist()
18
+
19
+
20
+ def langevin_fn3(debug_lst, model_control, model3, label_ids, step_size, sample, mean, sigma,
21
+ alpha, t, prev_sample): # current best.
22
+ if t[0].item() < 10:
23
+ K = 0
24
+ else:
25
+ K = 3
26
+ # K = 3
27
+
28
+ if t[0].item() > 0:
29
+ tt = t[0].item() - 1
30
+ else:
31
+ tt = 200
32
+ label_ids = label_ids.cuda()
33
+ tgt_embs = model3(label_ids[:, sample.size(1):])
34
+
35
+ label_ids2 = label_ids.clone()
36
+ label_ids2[:, :65] = -100
37
+ input_embs_param = th.nn.Parameter(sample)
38
+ if False:
39
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
40
+ debug_lst.append(get_score(input_embs, label_ids2, model_control, t=tt))
41
+ with th.enable_grad():
42
+ for i in range(K):
43
+ optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
44
+ optimizer.zero_grad()
45
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
46
+ model_out = model_control(input_embs=input_embs,
47
+ labels=label_ids2, t=tt)
48
+
49
+ coef = 0.01
50
+ # coef=1.
51
+ if sigma.mean() == 0:
52
+ logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
53
+ else:
54
+ logp_term = coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
55
+ # print(model_out.loss, f'start_{i}', logp_term.item(), t[0].item(), sigma.mean().item())
56
+ loss = model_out.loss + logp_term
57
+ loss.backward()
58
+ optimizer.step()
59
+ epsilon = th.randn_like(input_embs_param.data)
60
+ input_embs_param = th.nn.Parameter((input_embs_param.data + 0.0 * sigma.mean().item() * epsilon).detach())
61
+ # input_embs_param = th.nn.Parameter((input_embs_param.data +
62
+ # np.sqrt(2*sigma.mean().item()) * epsilon).detach())
63
+
64
+ # input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
65
+ # model_out = model_control(input_embs=input_embs,
66
+ # labels=label_ids2,
67
+ # t=tt)
68
+ # print(model_out.loss, 'end')
69
+
70
+ return input_embs_param.data
71
+
72
+ def langevin_fn4(debug_lst, model_control, model3, label_ids, step_size, sample, mean, sigma,
73
+ alpha, t, prev_sample): # current best.
74
+ if t[0].item() < 10:
75
+ K = 0
76
+ else:
77
+ K = 3
78
+
79
+ if t[0].item() > 0:
80
+ tt = t[0].item() - 1
81
+ else:
82
+ tt = 200
83
+ label_ids = label_ids.cuda()
84
+ input_embs_param = th.nn.Parameter(sample)
85
+ if False:  # disabled debug block; tgt_embs and label_ids2 are not defined in this function
86
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
87
+ debug_lst.append(get_score(input_embs, label_ids2, model_control, t=tt))
88
+ with th.enable_grad():
89
+ for i in range(K):
90
+ optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
91
+ optimizer.zero_grad()
92
+ # print(input_embs_param.shape, label_ids.shape)
93
+ model_out = model_control(input_embs=input_embs_param, pos_ids=label_ids, t=tt)
94
+
95
+ coef = 0.0001 # prev default.
96
+ # coef = 0.001
97
+ # coef = 0.0005
98
+
99
+
100
+ # coef=1.
101
+ if sigma.mean() == 0:
102
+ logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
103
+ else:
104
+ logp_term = coef * ((mean - input_embs_param)**2 / sigma).mean(dim=0).sum()
105
+ print(model_out.loss, f'start_{i}', logp_term.item(),
106
+ t[0].item(), sigma.mean().item())
107
+ loss = model_out.loss + logp_term
108
+ loss.backward()
109
+ optimizer.step()
110
+ epsilon = th.randn_like(input_embs_param.data)
111
+ input_embs_param = th.nn.Parameter((input_embs_param.data + 0.0*sigma.mean().item() * epsilon).detach())
112
+ # input_embs_param = th.nn.Parameter((input_embs_param.data +
113
+ # np.sqrt(2*sigma.mean().item()) * epsilon).detach())
114
+
115
+ model_out = model_control(input_embs=input_embs_param, pos_ids=label_ids, t=tt)
116
+ print(model_out.loss, 'end')
117
+
118
+ return input_embs_param.data
119
+
120
+ def langevin_fn_length(coeff, diffusion, partial_mask, diff_model, tgt_embs, step_size, sample, mean, sigma,
121
+ alpha, t, prev_sample): # current best.
122
+ if t[0].item() < 10:
123
+ K = 0
124
+ else:
125
+ K = 3
126
+
127
+ if t[0].item() > 0:
128
+ tt = t[0].item() - 1
129
+ else:
130
+ tt = 200
131
+ input_embs_param = th.nn.Parameter(sample)
132
+ if False:  # disabled debug block; debug_lst, label_ids2 and model_control are not defined in this function
133
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
134
+ debug_lst.append(get_score(input_embs, label_ids2, model_control, t=tt))
135
+ with th.enable_grad():
136
+ for i in range(K):
137
+ optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
138
+ optimizer.zero_grad()
139
+ print(t.shape)
140
+ # print(input_embs_param.shape, label_ids.shape)
141
+ out = diffusion.p_mean_variance(
142
+ diff_model,
143
+ input_embs_param,
144
+ t,
145
+ clip_denoised=False,
146
+ denoised_fn=None,
147
+ model_kwargs={},
148
+ )
149
+
150
+ # model_out = model_control(input_embs=input_embs_param, pos_ids=label_ids, t=tt)
151
+ coef = coeff
152
+ # coef = 0.0001 # prev default.
153
+ # coef = 0.001
154
+ # coef = 0.0005
155
+
156
+
157
+ # coef=1.
158
+ if sigma.mean() == 0:
159
+ logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
160
+ infill_loss = (out['pred_xstart'][~partial_mask] - tgt_embs[~partial_mask]) ** 2
161
+ infill_loss = infill_loss.mean(dim=0).sum()
162
+ else:
163
+ logp_term = coef * ((mean - input_embs_param)**2 / sigma).mean(dim=0).sum()
164
+ # print(out['pred_xstart'].shape, tgt_embs.shape)
165
+ # print(partial_mask[0])
166
+ infill_loss = ((out['pred_xstart'][~partial_mask] - tgt_embs[~partial_mask]) ** 2).view(tgt_embs.size(0), -1, tgt_embs.size(-1) )
167
+ # print(infill_loss.shape, ((mean - input_embs_param)**2).shape )
168
+ infill_loss = (infill_loss/sigma.mean()).mean(dim=0).sum()
169
+ print(infill_loss, f'start_{i}', logp_term.item(),
170
+ t[0].item(), sigma.mean().item())
171
+ loss = logp_term + infill_loss
172
+ loss.backward()
173
+ optimizer.step()
174
+ epsilon = th.randn_like(input_embs_param.data)
175
+ input_embs_param = th.nn.Parameter((input_embs_param.data + 0.0*sigma.mean().item() * epsilon).detach())
176
+ # input_embs_param = th.nn.Parameter((input_embs_param.data +
177
+ # np.sqrt(2*sigma.mean().item()) * epsilon).detach())
178
+
179
+ # model_out = model_control(input_embs=input_embs_param, pos_ids=label_ids, t=tt)
180
+ # print(model_out.loss, 'end')
181
+
182
+ return input_embs_param.data
183
+
184
+ def langevin_fn_tree(coeff, model_control, model3, label_ids, step_size, sample, mean, sigma,
185
+ alpha, t, prev_sample): # current best.
186
+ if t[0].item() < 10:
187
+ K = 0
188
+ else:
189
+ K = 3
190
+
191
+ if t[0].item() > 0:
192
+ tt = t[0].item() - 1
193
+ else:
194
+ tt = 200
195
+ label_ids = label_ids.cuda()
196
+ input_embs_param = th.nn.Parameter(sample)
197
+
198
+ with th.enable_grad():
199
+ for i in range(K):
200
+ optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
201
+ optimizer.zero_grad()
202
+ # print(input_embs_param.shape, label_ids.shape)
203
+ model_out = model_control(input_embs=input_embs_param, parse_chart=label_ids, t=tt)
204
+
205
+ # coef = 0.0001
206
+ # coef = 0.001
207
+ # coef = 0.01
208
+
209
+ # coef = 0.1 # good for partial.
210
+ # coef=0.001 # also good for full (more fluent).
211
+ # coef=0.0001
212
+
213
+ # coef=0.0005 # good for full.
214
+ coef = coeff
215
+
216
+ # coef = 0.5
217
+
218
+
219
+ # coef=1.
220
+ if sigma.mean() == 0:
221
+ logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
222
+ else:
223
+ logp_term = coef * ((mean - input_embs_param)**2 / sigma).mean(dim=0).sum()
224
+ # print(model_out.loss, f'start_{i}', logp_term.item(),
225
+ # t[0].item(), sigma.mean().item())
226
+ loss = model_out.loss + logp_term
227
+ loss.backward()
228
+ optimizer.step()
229
+ epsilon = th.randn_like(input_embs_param.data)
230
+ input_embs_param = th.nn.Parameter((input_embs_param.data + 0.0*sigma.mean().item() * epsilon).detach())
231
+ # input_embs_param = th.nn.Parameter((input_embs_param.data +
232
+ # np.sqrt(2*sigma.mean().item()) * epsilon).detach())
233
+
234
+ # COMMENT OUT
235
+ # model_out = model_control(input_embs=input_embs_param, parse_chart=label_ids, t=tt)
236
+ # print(model_out.loss, 'end')
237
+
238
+ return input_embs_param.data
239
+
240
+ def langevin_fn1(debug_lst, model_control, model3, label_ids, step_size, sample, mean, sigma,
241
+ alpha, t, prev_sample): # current best.
242
+ if t[0].item() < 10:
243
+ K = 0
244
+ else:
245
+ K = 1
246
+ # K = 3
247
+
248
+ if t[0].item() > 0:
249
+ tt = t[0].item() - 1
250
+ else:
251
+ tt = 200
252
+ label_ids = label_ids.cuda()
253
+ tgt_embs = model3(label_ids[:, sample.size(1):])
254
+
255
+ label_ids2 = label_ids.clone()
256
+ label_ids2[:, :65] = -100
257
+ input_embs_param = th.nn.Parameter(sample)
258
+ if True:
259
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
260
+ debug_lst.append(get_score(input_embs, label_ids2, model_control, t=tt))
261
+ with th.enable_grad():
262
+ for i in range(K):
263
+ optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
264
+ optimizer.zero_grad()
265
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
266
+ model_out = model_control(input_embs=input_embs,
267
+ labels=label_ids2, t=tt)
268
+
269
+ # coef = 0.0
270
+ # if sigma.mean() == 0:
271
+ # logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
272
+ # else:
273
+ # logp_term = coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
274
+ print(model_out.loss, f'start_{i}', t[0].item(), sigma.mean().item())
275
+ coef = 3.
276
+ loss = model_out.loss # + logp_term
277
+ loss.backward()
278
+ # print(input_embs_param.grad.shape, )
279
+ input_embs_param.data = input_embs_param.data - coef * sigma.mean().item() * input_embs_param.grad
280
+ # optimizer.step()
281
+ # epsilon = th.randn_like(input_embs_param.data)
282
+ # input_embs_param = th.nn.Parameter((input_embs_param.data + 0.0 * sigma.mean().item() * epsilon).detach())
283
+ # input_embs_param = th.nn.Parameter((input_embs_param.data +
284
+ # np.sqrt(2*sigma.mean().item()) * epsilon).detach())
285
+
286
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
287
+ model_out = model_control(input_embs=input_embs,
288
+ labels=label_ids2,
289
+ t=tt)
290
+ print(model_out.loss, 'end')
291
+ # if True:
292
+ # debug_lst.append(get_score(input_embs, label_ids2, model_control, t=tt))
293
+
294
+ return input_embs_param.data
295
+
296
+
297
+ def langevin_fn3_compose(debug_lst, model_control, model3, label_ids_lst, step_size, sample, mean, sigma,
298
+ alpha, t, prev_sample): # current best.
299
+ if t[0].item() < 10:
300
+ K = 0
301
+ else:
302
+ K = 3
303
+ # K = 3
304
+
305
+ if t[0].item() > 0:
306
+ tt = t[0].item() - 1
307
+ else:
308
+ tt = 200
309
+
310
+ tgt_embs_lst = [model3(label_ids[:, sample.size(1):]) for label_ids in label_ids_lst]
311
+
312
+ label_ids2_lst = []
313
+ for label_ids in label_ids_lst:
314
+ label_ids2 = label_ids.clone()
315
+ label_ids2[:, :65] = -100
316
+ label_ids2_lst.append(label_ids2)
317
+
318
+ input_embs_param = th.nn.Parameter(sample)
319
+ if True:
320
+ part_score = []
321
+ for (tgt_embs,label_ids2) in zip(tgt_embs_lst, label_ids2_lst):
322
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
323
+ score_ = get_score(input_embs, label_ids2, model_control, t=tt)
324
+ part_score.append(score_)
325
+ debug_lst.append(part_score)
326
+ with th.enable_grad():
327
+ for i in range(K):
328
+ optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
329
+ optimizer.zero_grad()
330
+ cum_loss = 0
331
+ for (tgt_embs, label_ids2) in zip(tgt_embs_lst, label_ids2_lst):
332
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
333
+ model_out = model_control(input_embs=input_embs,
334
+ labels=label_ids2, t=tt)
335
+ cum_loss += model_out.loss
336
+
337
+ coef = 0.01
338
+ if sigma.mean() == 0:
339
+ logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
340
+ else:
341
+ logp_term = coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
342
+ print(cum_loss, f'start_{i}', logp_term.item(), t[0].item(), sigma.mean().item())
343
+ loss = cum_loss + logp_term
344
+ loss.backward()
345
+ optimizer.step()
346
+ epsilon = th.randn_like(input_embs_param.data)
347
+ input_embs_param = th.nn.Parameter((input_embs_param.data + 0.0 * sigma.mean().item() * epsilon).detach())
348
+
349
+ part_score = []
350
+ for (tgt_embs, label_ids2) in zip(tgt_embs_lst, label_ids2_lst):
351
+ input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
352
+ score_ = get_score(input_embs, label_ids2, model_control, t=tt)
353
+ part_score.append(score_)
354
+
355
+ return input_embs_param.data
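All of the langevin_fn* variants above share one core update: K gradient steps that balance a task-specific control loss against a Gaussian term coef * ((mean - x)**2 / sigma) anchoring the sample to the diffusion posterior mean; the Langevin noise injection is present but multiplied by 0.0, so in practice these are plain gradient refinements. A condensed sketch of that shared core, with model_loss_fn standing in for whichever control model a given variant uses:

    import torch as th

    def guided_refine(x, mean, sigma, model_loss_fn, coef=0.01, steps=3, lr=0.1):
        x = th.nn.Parameter(x)
        for _ in range(steps):
            opt = th.optim.Adagrad([x], lr=lr)  # the originals also rebuild the optimizer each step
            opt.zero_grad()
            # keep the sample near the posterior mean, scaled by its variance
            # (the originals special-case sigma == 0 with a unit denominator)
            prior = coef * ((mean - x) ** 2 / sigma).mean(dim=0).sum()
            loss = model_loss_fn(x) + prior
            loss.backward()
            opt.step()
        return x.data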
src/scripts/mydatasets.py ADDED
@@ -0,0 +1,326 @@
1
+ import os
2
+ import glob
3
+ import torch
4
+ import random
5
+ import selfies as sf
6
+ from rdkit import Chem
7
+ from datasets import load_dataset
8
+ from transformers import T5EncoderModel
9
+ from torch.utils.data import DistributedSampler, DataLoader, Dataset
10
+
11
+
12
+ def get_dataloader(dataset, batchsize, rank, world_size):
13
+ sampler = DistributedSampler(
14
+ dataset, num_replicas=world_size, rank=rank, shuffle=True
15
+ )
16
+
17
+ def collate(batch):
18
+ selfies_ids = [i["selfies_ids"] for i in batch]
19
+ caption_state = [i["caption_state"] for i in batch]
20
+ caption_mask = [i["caption_mask"] for i in batch]
21
+ corrupted_selfies_ids = [i["corrupted_selfies_ids"] for i in batch]
22
+ return (
23
+ torch.concat(selfies_ids, dim=0),
24
+ torch.concat(caption_state, dim=0),
25
+ torch.concat(caption_mask, dim=0),
26
+ torch.concat(corrupted_selfies_ids, dim=0),
27
+ )
28
+
29
+ dataloader = DataLoader(
30
+ dataset,
31
+ batch_size=batchsize,
32
+ shuffle=False,
33
+ collate_fn=collate,
34
+ sampler=sampler,
35
+ )
36
+
37
+ def cycle():
38
+ ec = 0
39
+ while True:
40
+ dataloader.sampler.set_epoch(ec)
41
+ for i in dataloader:
42
+ yield i
43
+ ec += 1
44
+
45
+ return iter(cycle())
46
+
47
+
48
+ class Lang2molDataset_train(Dataset):
49
+ def __init__(
50
+ self,
51
+ dir,
52
+ tokenizer,
53
+ split,
54
+ dataset_name,
55
+ pre=None,
56
+ prob=0,
57
+ load_state=True,
58
+ corrupt_prob=0.4,
59
+ token_max_length=256,
60
+ ):
61
+ super().__init__()
62
+ self.dir = dir
63
+ self.tokenizer = tokenizer
64
+ self.split = split
65
+ self.pre = pre
66
+ self.prob = prob
67
+ self.corrupt_prob = corrupt_prob
68
+ self.token_max_length = token_max_length
69
+ self.dataset_name = dataset_name
70
+ self.ori_data = self.create_data()
71
+ self.load_state = load_state
72
+ self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
73
+ self.model.to("cuda")
74
+ self.model.eval()
75
+
76
+ def create_data(self):
77
+ try:
78
+ dataset = load_dataset(
79
+ self.dataset_name,
80
+ token=True,
81
+ split=self.split,
82
+ ).sort("id")
83
+ except:  # older datasets versions expect use_auth_token instead of token
84
+ dataset = load_dataset(
85
+ self.dataset_name,
86
+ use_auth_token=True,
87
+ split=self.split,
88
+ ).sort("id")
89
+
90
+ return [
91
+ (int(sample_id), sample_selfies, sample_caption, sample_smiles)
92
+ for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
93
+ dataset["id"],
94
+ dataset["selfies"],
95
+ dataset["caption"],
96
+ dataset["smiles"],
97
+ )
98
+ ]
99
+
100
+ def __len__(self):
101
+ return len(self.ori_data)
102
+
103
+ def permute(self, selfies):
104
+ if random.random() < self.prob:
105
+ return changeorder(selfies, shuffle=True)
106
+ else:
107
+ return selfies
108
+
109
+ def __getitem__(self, idx):
110
+ data = self.ori_data[idx]
111
+ sample = {
112
+ "id": data[0],
113
+ "selfies": self.permute(data[1]),
114
+ "caption": data[2],
115
+ "smiles": data[3],
116
+ }
117
+
118
+ # Molecules
119
+ output_molecule = self.tokenizer(
120
+ sample["selfies"],
121
+ max_length=self.token_max_length,
122
+ truncation=True,
123
+ padding="max_length",
124
+ add_special_tokens=True,
125
+ return_tensors="pt",
126
+ return_attention_mask=True,
127
+ )
128
+ sample["selfies_ids"] = output_molecule["input_ids"]
129
+ sample["corrupted_selfies_ids"] = sample["selfies_ids"]
130
+
131
+ # Captions
132
+ output_caption = self.tokenizer(
133
+ sample["caption"],
134
+ max_length=self.token_max_length,
135
+ truncation=True,
136
+ padding="max_length",
137
+ add_special_tokens=True,
138
+ return_tensors="pt",
139
+ return_attention_mask=True,
140
+ )
141
+ sample["caption_state"] = self.model(
142
+ input_ids=output_caption["input_ids"].to("cuda"),
143
+ attention_mask=output_caption["attention_mask"].to("cuda"),
144
+ ).last_hidden_state
145
+ sample["caption_mask"] = output_caption["attention_mask"]
146
+
147
+ return sample
148
+
149
+
150
+ class Lang2molDataset_eval(Dataset):
151
+ def __init__(
152
+ self,
153
+ dir,
154
+ tokenizer,
155
+ split,
156
+ dataset_name,
157
+ pre=None,
158
+ prob=0,
159
+ load_state=True,
160
+ corrupt_prob=0.4,
161
+ token_max_length=512,
162
+ ):
163
+ super().__init__()
164
+ self.dir = dir
165
+ self.tokenizer = tokenizer
166
+ self.split = split
167
+ self.pre = pre
168
+ self.prob = prob
169
+ self.corrupt_prob = corrupt_prob
170
+ self.token_max_length = token_max_length
171
+ self.dataset_name = dataset_name
172
+ self.ori_data = self.create_data()
173
+ self.load_state = load_state
174
+ self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
175
+ self.model.to("cuda")
176
+ self.model.eval()
177
+
178
+ def create_data(self):
179
+ try:
180
+ dataset = load_dataset(
181
+ self.dataset_name,
182
+ token=True,
183
+ split=self.split,
184
+ ).sort("id")
185
+ except:  # older datasets versions expect use_auth_token instead of token
186
+ dataset = load_dataset(
187
+ self.dataset_name,
188
+ use_auth_token=True,
189
+ split=self.split,
190
+ ).sort("id")
191
+
192
+ return [
193
+ (int(sample_id), sample_selfies, sample_caption, sample_smiles)
194
+ for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
195
+ dataset["id"],
196
+ dataset["selfies"],
197
+ dataset["caption"],
198
+ dataset["smiles"],
199
+ )
200
+ ]
201
+
202
+ def __len__(self):
203
+ return len(self.ori_data)
204
+
205
+ def permute(self, selfies):
206
+ if random.random() < self.prob:
207
+ return changeorder(selfies, shuffle=True)
208
+ else:
209
+ return selfies
210
+
211
+ def __getitem__(self, idx):
212
+ data = self.ori_data[idx]
213
+ sample = {
214
+ "id": data[0],
215
+ "selfies": self.permute(data[1]),
216
+ "caption": data[2],
217
+ "smiles": data[3],
218
+ }
219
+
220
+ output_caption = self.tokenizer(
221
+ sample["caption"],
222
+ max_length=self.token_max_length,
223
+ truncation=True,
224
+ padding="max_length",
225
+ add_special_tokens=True,
226
+ return_tensors="pt",
227
+ return_attention_mask=True,
228
+ )
229
+ sample["caption_state"] = self.model(
230
+ input_ids=output_caption["input_ids"].to("cuda"),
231
+ attention_mask=output_caption["attention_mask"].to("cuda"),
232
+ ).last_hidden_state
233
+ sample["caption_mask"] = output_caption["attention_mask"]
234
+
235
+ return sample
236
+
237
+
238
+ class Lang2molDataset_submission(Dataset):
239
+ def __init__(
240
+ self,
241
+ dir,
242
+ tokenizer,
243
+ split,
244
+ dataset_name,
245
+ pre=None,
246
+ prob=0,
247
+ load_state=True,
248
+ corrupt_prob=0.4,
249
+ token_max_length=256,
250
+ ):
251
+ super().__init__()
252
+ self.dir = dir
253
+ self.tokenizer = tokenizer
254
+ self.split = split
255
+ self.pre = pre
256
+ self.prob = prob
257
+ self.corrupt_prob = corrupt_prob
258
+ self.token_max_length = token_max_length
259
+ self.dataset_name = dataset_name
260
+ self.ori_data = self.create_data()
261
+ self.load_state = load_state
262
+ self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
263
+ self.model.to("cuda")
264
+ self.model.eval()
265
+
266
+ def create_data(self):
267
+ try:
268
+ dataset = load_dataset(
269
+ self.dataset_name,
270
+ token=True,
271
+ split=self.split,
272
+ )
273
+ except:  # older datasets versions expect use_auth_token instead of token
274
+ dataset = load_dataset(
275
+ self.dataset_name,
276
+ use_auth_token=True,
277
+ split=self.split,
278
+ )
279
+
280
+ return [sample_caption for sample_caption in dataset["caption"]]
281
+
282
+ def __len__(self):
283
+ return len(self.ori_data)
284
+
285
+ def permute(self, selfies):
286
+ if random.random() < self.prob:
287
+ return changeorder(selfies, shuffle=True)
288
+ else:
289
+ return selfies
290
+
291
+ def __getitem__(self, idx):
292
+ sample = {"caption": self.ori_data[idx]}
293
+
294
+ # Captions
295
+ output_caption = self.tokenizer(
296
+ sample["caption"],
297
+ max_length=self.token_max_length,
298
+ truncation=True,
299
+ padding="max_length",
300
+ add_special_tokens=True,
301
+ return_tensors="pt",
302
+ return_attention_mask=True,
303
+ )
304
+ sample["caption_state"] = self.model(
305
+ input_ids=output_caption["input_ids"].to("cuda"),
306
+ attention_mask=output_caption["attention_mask"].to("cuda"),
307
+ ).last_hidden_state
308
+ sample["caption_mask"] = output_caption["attention_mask"]
309
+
310
+ return sample
311
+
312
+
313
+ def changeorder(selfies, shuffle):
314
+ smiles = sf.decoder(selfies)  # SELFIES -> SMILES (sf.encoder goes the other way)
315
+ mol = Chem.MolFromSmiles(smiles)
316
+ if mol is None:
317
+ return selfies
318
+ Chem.Kekulize(mol)
319
+ atom_indices = [atom.GetIdx() for atom in mol.GetAtoms()]
320
+ if shuffle:
321
+ random.shuffle(atom_indices)
322
+ reordered_mol = Chem.RenumberAtoms(mol, atom_indices)
323
+ new_smiles = Chem.MolToSmiles(reordered_mol, kekuleSmiles=True)
324
+ new_selfies = sf.encoder(new_smiles)  # SMILES -> SELFIES
325
+
326
+ return new_selfies
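A note on the selfies API direction used by changeorder (and corrected above): sf.encoder maps SMILES to SELFIES and sf.decoder maps SELFIES back to SMILES, so the function must decode first, renumber atoms via RDKit, then re-encode. A quick roundtrip check:

    import selfies as sf

    smiles = "CCO"           # ethanol
    s = sf.encoder(smiles)   # SMILES -> SELFIES, e.g. '[C][C][O]'
    print(sf.decoder(s))     # SELFIES -> SMILES, 'CCO' again

Separately, the __getitem__ methods above run the frozen T5 encoder without torch.no_grad(), so the cached caption_state tensors keep their autograd graphs alive; wrapping that call in torch.no_grad() would cut memory if the states are never backpropagated through (an observation, not a change made here).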
src/scripts/mytokenizers.py ADDED
@@ -0,0 +1,249 @@
1
+ import os
2
+ import torch
3
+ import random
4
+ import selfies as sf
5
+ from transformers import AutoTokenizer
6
+
7
+
8
+ ################################
9
+ def getrandomnumber(numbers, k, weights=None):
10
+ if k == 1:
11
+ return random.choices(numbers, weights=weights, k=k)[0]
12
+ else:
13
+ return random.choices(numbers, weights=weights, k=k)
14
+
15
+
16
+ # simple smiles tokenizer
17
+ # treat every charater as token
18
+ def build_simple_smiles_vocab(dir):
19
+ assert dir is not None, "dir and smiles_vocab can not be None at the same time."
20
+ if not os.path.exists(os.path.join(dir, "simple_smiles_tokenizer_vocab.txt")):
21
+ # print('Generating Vocabulary for {} ...'.format(dir))
22
+ dirs = list(
23
+ os.path.join(dir, i) for i in ["train.txt", "validation.txt", "test.txt"]
24
+ )
25
+ smiles = []
26
+ for idir in dirs:
27
+ with open(idir, "r") as f:
28
+ for i, line in enumerate(f):
29
+ if i == 0:
30
+ continue
31
+ line = line.split("\t")
32
+ assert len(line) == 3, "Dataset format error."
33
+ if line[1] != "*":
34
+ smiles.append(line[1].strip())
35
+ char_set = set()
36
+ for smi in smiles:
37
+ for c in smi:
38
+ char_set.add(c)
39
+ vocabstring = "".join(char_set)
40
+ with open(os.path.join(dir, "simple_smiles_tokenizer_vocab.txt"), "w") as f:
41
+ f.write(os.path.join(vocabstring))
42
+ return vocabstring
43
+ else:
44
+ print("Reading in Vocabulary...")
45
+ with open(os.path.join(dir, "simple_smiles_tokenizer_vocab.txt"), "r") as f:
46
+ vocabstring = f.readline().strip()
47
+ return vocabstring
48
+
49
+
50
+ class Tokenizer:
51
+ def __init__(
52
+ self,
53
+ pretrained_name="QizhiPei/biot5-base-text2mol",
54
+ selfies_dict_path=os.path.join("dataset", "selfies_dict.txt"),
55
+ ):
56
+ self.tokenizer = self.get_tokenizer(pretrained_name, selfies_dict_path)
57
+
58
+ def get_tokenizer(self, pretrained_name, selfies_dict_path):
59
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_name, use_fast=True)
60
+ tokenizer.model_max_length = int(1e9)
61
+
62
+ amino_acids = [
63
+ "A",
64
+ "C",
65
+ "D",
66
+ "E",
67
+ "F",
68
+ "G",
69
+ "H",
70
+ "I",
71
+ "K",
72
+ "L",
73
+ "M",
74
+ "N",
75
+ "P",
76
+ "Q",
77
+ "R",
78
+ "S",
79
+ "T",
80
+ "V",
81
+ "W",
82
+ "Y",
83
+ ]
84
+ prefixed_amino_acids = [f"<p>{aa}" for aa in amino_acids]
85
+ tokenizer.add_tokens(prefixed_amino_acids)
86
+
87
+ selfies_dict_list = [line.strip() for line in open(selfies_dict_path)]
88
+ tokenizer.add_tokens(selfies_dict_list)
89
+
90
+ special_tokens_dict = {
91
+ "additional_special_tokens": [
92
+ "<bom>",
93
+ "<eom>",
94
+ "<bop>",
95
+ "<eop>",
96
+ "MOLECULE NAME",
97
+ "DESCRIPTION",
98
+ "PROTEIN NAME",
99
+ "FUNCTION",
100
+ "SUBCELLULAR LOCATION",
101
+ "PROTEIN FAMILIES",
102
+ ]
103
+ }
104
+ tokenizer.add_special_tokens(special_tokens_dict)
105
+ return tokenizer
106
+
107
+ def __call__(self, *args, **kwds):
108
+ return self.tokenizer(*args, **kwds)
109
+
110
+ def __len__(self):
111
+ return len(self.tokenizer)
112
+
113
+ def corrupt(self, selfies_list: list):
114
+ tensors = []
115
+ if type(selfies_list) is str:
116
+ selfies_list = [selfies_list]
117
+ for selfies in selfies_list:
118
+ tensors.append(self.corrupt_one(selfies))
119
+ return torch.concat(tensors, dim=0)
120
+
121
+ # TODO: rewrite this for selfies
122
+ def corrupt_one(self, selfies):
123
+ smi = sf.decoder(selfies)
124
+ # res = [self.toktoid[i] for i in self.rg.findall(smi)]
125
+ res = [i for i in self.rg.findall(smi)]
126
+ total_length = len(res) + 2
127
+ if total_length > self.max_len:
128
+ return self.encode_one(smi)
129
+ ######################## start corruption ###########################
130
+ r = random.random()
131
+ if r < 0.3:
132
+ pa, ring = True, True
133
+ elif r < 0.65:
134
+ pa, ring = True, False
135
+ else:
136
+ pa, ring = False, True
137
+ #########################
138
+ max_ring_num = 1
139
+ ringpos = []
140
+ papos = []
141
+ for pos, at in enumerate(res):
142
+ if at == "(" or at == ")":
143
+ papos.append(pos)
144
+ elif at.isnumeric():
145
+ max_ring_num = max(max_ring_num, int(at))
146
+ ringpos.append(pos)
147
+ # ( & ) remove
148
+ r = random.random()
149
+ if r < 0.3:
150
+ remove, padd = True, True
151
+ elif r < 0.65:
152
+ remove, padd = True, False
153
+ else:
154
+ remove, padd = False, True
155
+ if pa and len(papos) > 0:
156
+ if remove:
157
+ # remove pa
158
+ n_remove = getrandomnumber(
159
+ [1, 2, 3, 4], 1, weights=[0.6, 0.2, 0.1, 0.1]
160
+ )
161
+ p_remove = set(random.choices(papos, weights=None, k=n_remove))
162
+ total_length -= len(p_remove)
163
+ for p in p_remove:
164
+ res[p] = None
165
+ # print('debug pa delete {}'.format(p))
166
+ # Ring remove
167
+ r = random.random()
168
+ if r < 0.3:
169
+ remove, radd = True, True
170
+ elif r < 0.65:
171
+ remove, radd = True, False
172
+ else:
173
+ remove, radd = False, True
174
+ if ring and len(ringpos) > 0:
175
+ if remove:
176
+ # remove ring
177
+ n_remove = getrandomnumber(
178
+ [1, 2, 3, 4], 1, weights=[0.7, 0.2, 0.05, 0.05]
179
+ )
180
+ p_remove = set(random.choices(ringpos, weights=None, k=n_remove))
181
+ total_length -= len(p_remove)
182
+ for p in p_remove:
183
+ res[p] = None
184
+ # print('debug ring delete {}'.format(p))
185
+ # ring add & ( ) add
186
+ if pa:
187
+ if padd:
188
+ n_add = getrandomnumber([1, 2, 3], 1, weights=[0.8, 0.2, 0.1])
189
+ n_add = min(self.max_len - total_length, n_add)
190
+ for _ in range(n_add):
191
+ sele = random.randrange(len(res) + 1)
192
+ res.insert(sele, "(" if random.random() < 0.5 else ")")
193
+ # print('debug pa add {}'.format(sele))
194
+ total_length += 1
195
+ if ring:
196
+ if radd:
197
+ n_add = getrandomnumber([1, 2, 3], 1, weights=[0.8, 0.2, 0.1])
198
+ n_add = min(self.max_len - total_length, n_add)
199
+ for _ in range(n_add):
200
+ sele = random.randrange(len(res) + 1)
201
+ res.insert(sele, str(random.randrange(1, max_ring_num + 1)))
202
+ # print('debug ring add {}'.format(sele))
203
+ total_length += 1
204
+
205
+ ########################## end corruption ###############################
206
+ # print('test:',res)
207
+ # print('test:',''.join([i for i in res if i is not None]))
208
+
209
+ res = [self.toktoid[i] for i in res if i is not None]
210
+ res = [1] + res + [2]
211
+ if len(res) < self.max_len:
212
+ res += [0] * (self.max_len - len(res))
213
+ else:
214
+ res = res[: self.max_len]
215
+ res[-1] = 2
216
+ return torch.LongTensor([res])
217
+
218
+ def decode_one(self, sample):
219
+ return self.tokenizer.decode(sample)
220
+
221
+ def decode(self, sample_list):
222
+ if len(sample_list.shape) == 1:
223
+ return [self.decode_one(sample_list)]
224
+ return [self.decode_one(sample) for sample in sample_list]
225
+
226
+ if __name__ == "__main__":
227
+ import selfies as sf
228
+
229
+ tokenizer = Tokenizer(
230
+ selfies_dict_path=r"D:\molecule\mol-lang-bridge\dataset\selfies_dict.txt"
231
+ )
232
+ smiles = [
233
+ "[210Po]",
234
+ "C[C@H]1C(=O)[C@H]([C@H]([C@H](O1)OP(=O)(O)OP(=O)(O)OC[C@@H]2[C@H](C[C@@H](O2)N3C=C(C(=O)NC3=O)C)O)O)O",
235
+ "C(O)P(=O)(O)[O-]",
236
+ "CCCCCCCCCCCC(=O)OC(=O)CCCCCCCCCCC",
237
+ "C[C@]12CC[C@H](C[C@H]1CC[C@@H]3[C@@H]2CC[C@]4([C@H]3CCC4=O)C)O[C@H]5[C@@H]([C@H]([C@@H]([C@H](O5)C(=O)O)O)O)O",
238
+ ]
239
+ selfies = [sf.encoder(smiles_ele) for smiles_ele in smiles]
240
+ output = tokenizer(
241
+ selfies,
242
+ max_length=512,
243
+ truncation=True,
244
+ padding="max_length",
245
+ add_special_tokens=True,
246
+ return_tensors="pt",
247
+ return_attention_mask=True,
248
+ )
249
+ print(output["input_ids"])
src/scripts/nll.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ Approximate the bits/dimension for an image model.
3
+ """
4
+
5
+ import argparse
6
+ import os, json
7
+ import torch as th
8
+ import numpy as np
9
+ import torch.distributed as dist
10
+
11
+ from improved_diffusion import dist_util, logger
12
+ from improved_diffusion.image_datasets import load_data
13
+ from improved_diffusion.text_datasets import load_data_text, load_synthetic_data
14
+ from improved_diffusion.script_util import (
15
+ model_and_diffusion_defaults,
16
+ create_model_and_diffusion,
17
+ add_dict_to_argparser,
18
+ args_to_dict,
19
+ )
20
+ from functools import partial
21
+ from transformers import set_seed
22
+ from improved_diffusion.test_util import get_weights, denoised_fn_round, compute_logp, load_results
23
+
24
+ def main():
25
+ set_seed(42)
26
+ args = create_argparser().parse_args()
27
+
28
+ # load configurations.
29
+ config_path = os.path.join(os.path.split(args.model_path)[0], "training_args.json")
30
+ print(config_path)
31
+ # sys.setdefaultencoding('utf-8')
32
+ with open(config_path, 'rb', ) as f:
33
+ training_args = json.load(f)
34
+
35
+ training_args['batch_size'] = args.batch_size
36
+ print(args.data_dir)
37
+ del training_args['data_dir']
38
+ # print(args.__dict__, training_args)
39
+ args.__dict__.update(training_args)
40
+ print(args.__dict__['batch_size'], training_args['batch_size'], args.clip_denoised, args.batch_size)
41
+ print(args.data_dir)
42
+ # if args.noise_level > 0.0: flag_noise=True #DEBUG
43
+ args.noise_level = 0.0
44
+ args.roc_train = 'diffusion_lm/ROCstory'
45
+ if args.modality == 'roc-aug':
46
+ args.modality = 'roc'
47
+ # DEBUG
48
+ dist_util.setup_dist()
49
+ logger.configure()
50
+
51
+ logger.log("creating model and diffusion...")
52
+ model, diffusion = create_model_and_diffusion(
53
+ **args_to_dict(args, model_and_diffusion_defaults().keys())
54
+ )
55
+ model.load_state_dict(th.load(args.model_path))
56
+ # model.load_state_dict(
57
+ # dist_util.load_state_dict(args.model_path, map_location="cpu")
58
+ # )
59
+ # diffusion.rescale_timesteps = False # IMPORTANT DEBUG --> REMOVE
60
+ model.to(dist_util.dev())
61
+ model.eval() # DEBUG
62
+
63
+ logger.log("creating data loader...")
64
+ if args.modality == 'image':
65
+ data = load_data(
66
+ data_dir=args.data_dir,
67
+ batch_size=args.batch_size,
68
+ image_size=args.image_size,
69
+ class_cond=args.class_cond,
70
+ deterministic=True,
71
+ )
72
+ elif args.modality == 'permuted_image':
73
+ # perm = np.arange(args.image_size * args.image_size)
74
+ # np.random.shuffle(perm)
75
+ model_path_base = os.path.split(args.model_path)[0]
76
+ print(f'load permutation to {model_path_base}/permutation.json')
77
+ with open(f'{model_path_base}/permutation.json', 'r') as f:
78
+ perm = json.load(f)
79
+ perm = np.array(perm)
80
+ data = load_data(
81
+ data_dir=args.data_dir,
82
+ batch_size=args.batch_size,
83
+ image_size=args.image_size,
84
+ class_cond=args.class_cond,
85
+ permutation=perm
86
+ )
87
+ elif args.modality == 'synth':
88
+ from improved_diffusion.rounding import load_models
89
+ model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
90
+ os.path.split(args.model_path)[0])
91
+
92
+ data = load_synthetic_data(
93
+ data_dir=args.data_dir,
94
+ batch_size=args.batch_size,
95
+ image_size=args.image_size,
96
+ class_cond=args.class_cond,
97
+ data_args=args,
98
+ model=model2,
99
+ split='train',
100
+ # split='valid',
101
+ deterministic=True
102
+
103
+ )
104
+ elif args.modality == 'pos':
105
+ from improved_diffusion.rounding import load_models
106
+ model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
107
+ os.path.split(args.model_path)[0])
108
+ data = load_synthetic_data(
109
+ data_dir=args.data_dir,
110
+ batch_size=args.batch_size,
111
+ image_size=args.image_size,
112
+ class_cond=args.class_cond,
113
+ data_args=args,
114
+ model=model2,
115
+ pos=True,
116
+ deterministic = True
117
+ )
118
+ else:
119
+ from improved_diffusion.rounding import load_models
120
+ model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
121
+ os.path.split(args.model_path)[0])
122
+ # print(tokenizer)
123
+ # rev_tokenizer = {k:int(v) for k, v in tokenizer.items()}
124
+ rev_tokenizer = {v:k for k, v in tokenizer.items()}
125
+
126
+ if args.training_mode == 'e2e':
127
+ print('e2e, load the right model embeddings', '*'*80)
128
+ model2.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu())
129
+
130
+ # print(rev_tokenizer)
131
+ data = load_data_text(
132
+ data_dir=args.data_dir,
133
+ batch_size=args.batch_size,
134
+ image_size=args.image_size,
135
+ class_cond=args.class_cond,
136
+ data_args=args,
137
+ model=model2,
138
+ deterministic=True,
139
+ task_mode=args.modality,
140
+ padding_mode=args.padding_mode, # block, pad
141
+ split=args.split,
142
+ load_vocab=rev_tokenizer,
143
+ )
144
+
145
+ logger.log("evaluating...")
146
+ run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised, args, model2)
147
+
148
+
149
+
150
+
151
+ def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised, args, model2):
152
+ all_bpd = []
153
+ all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
154
+ num_complete = 0
155
+ model3 = get_weights(model2, args)
156
+ while num_complete < num_samples:
157
+ batch, model_kwargs = next(data)
158
+ batch = batch.to(dist_util.dev())
159
+ model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
160
+ model_kwargs['mapping_func'] = partial(compute_logp, args, model3.cuda())
161
+ minibatch_metrics = diffusion.calc_bpd_loop(
162
+ model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs,
163
+ # denoised_fn=None,
164
+ denoised_fn=partial(denoised_fn_round, args, model3.cuda()) if args.clamp == 'clamp' else None,
165
+ )
166
+
167
+ for key, term_list in all_metrics.items():
168
+ terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size()
169
+ dist.all_reduce(terms)
170
+ term_list.append(terms.detach().cpu().numpy())
171
+
172
+ total_bpd = minibatch_metrics["total_bpd"]
173
+ total_bpd = total_bpd.mean() / dist.get_world_size()
174
+ dist.all_reduce(total_bpd)
175
+ all_bpd.append(total_bpd.item())
176
+ num_complete += dist.get_world_size() * batch.shape[0]
177
+
178
+ logger.log(f"done {num_complete} samples on {args.split}: bpd={np.mean(all_bpd)}, "
179
+ f"per token={np.mean(all_bpd) * args.in_channel} ", args.model_path)
180
+ temp_cat = np.mean(np.stack(all_metrics['vb']), axis=0)
181
+ if len(temp_cat) % 8 == 0:
182
+ print([y.sum() for y in np.split(np.mean(np.stack(all_metrics['vb']), axis=0), 8)])
183
+ else:
184
+ print(temp_cat[0].sum())
185
+ print([y.sum() for y in np.split(temp_cat[1:-1], 8)])
186
+ print(temp_cat[-1].sum())
187
+ vb_temp = np.mean(np.stack(all_metrics['vb']), axis=0)
188
+ print(vb_temp.shape, vb_temp.sum())
189
+ print(vb_temp[-10:])
190
+
191
+
192
+ if dist.get_rank() == 0:
193
+ for name, terms in all_metrics.items():
194
+ model_base_name = os.path.basename(
195
+ os.path.split(args.model_path)[0]) + f'.{os.path.split(args.model_path)[1]}'
196
+ # args.out_dir = os.path.join(args.out_dir, f"{model_base_name}.samples_{shape_str}.txt")
197
+ out_path = os.path.join(args.out_dir, f"{model_base_name}.{name}_{args.split}_{args.clamp}_terms.npz")
198
+ logger.log(f"saving {name} terms to {out_path}")
199
+ np.savez(out_path, np.mean(np.stack(terms), axis=0))
200
+
201
+ dist.barrier()
202
+ logger.log("evaluation complete")
203
+
204
+ if 'ema' in args.model_path:
205
+ json_path = os.path.join(os.path.split(args.model_path)[0], f'ema_score_{args.split}_nll.json')
206
+ elif args.clamp == 'noclamp':
207
+ json_path = os.path.join(os.path.split(args.model_path)[0], f'score_{args.split}_nll_noclamp.json')
208
+ else:
209
+ json_path = os.path.join(os.path.split(args.model_path)[0], f'score_{args.split}_nll.json')
210
+
211
+ print(f'written to {json_path}')
212
+ temp_cat = np.mean(np.stack(all_metrics['vb']), axis=0)
213
+ if len(temp_cat) % 8 == 0:
214
+ temp_cat = temp_cat
215
+ else:
216
+ temp_cat = temp_cat[1:-1]
217
+ json_dict = {
218
+ f'score_{args.split}_ppl_token': np.mean(all_bpd) * args.in_channel,
219
+ f'score_{args.split}_ppl_dim': np.mean(all_bpd),
220
+ f'break_down_{args.split}_dim' : [y.sum().item() for y in np.split(temp_cat, 8)],
221
+ f'last_10_{args.split}_dim': vb_temp[-10:].tolist(),
222
+ 'source_file': out_path,
223
+ 'num_samples':num_samples,
224
+ }
225
+ load_results(json_path, json_dict)
226
+
227
+
228
+ def create_argparser():
229
+ defaults = dict(
230
+ data_dir="", clip_denoised=False, num_samples=128, batch_size=64, model_path="",
231
+ out_dir="diffusion_lm/improved_diffusion/scores",
232
+ emb_scale_factor=1.0, split='train', debug_path='', clamp='clamp',
233
+ )
234
+ defaults.update(model_and_diffusion_defaults())
235
+ parser = argparse.ArgumentParser()
236
+ add_dict_to_argparser(parser, defaults)
237
+ return parser
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
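+ # Reading a saved terms file back (a sketch; np.savez stores an unnamed
+ # positional array under the key 'arr_0', and the path is a placeholder):
+ #   terms = np.load('scores/<model_base_name>.vb_valid_clamp_terms.npz')['arr_0']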
src/scripts/tree_helper.py ADDED
@@ -0,0 +1,110 @@
+ import torch
+ import nltk
+ from nltk.tree import Tree
+ import numpy as np
+
+
+ def collapse_unary_strip_pos(tree, strip_top=True):
+     """Collapse unary chains and strip part-of-speech tags."""
+
+     def strip_pos(tree):
+         if len(tree) == 1 and isinstance(tree[0], str):
+             return tree[0]
+         else:
+             return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree])
+
+     collapsed_tree = strip_pos(tree)
+     collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::")
+     if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"):
+         if strip_top:
+             if len(collapsed_tree) == 1:
+                 collapsed_tree = collapsed_tree[0]
+             else:
+                 collapsed_tree.set_label("")
+         elif len(collapsed_tree) == 1:
+             collapsed_tree[0].set_label(
+                 collapsed_tree.label() + "::" + collapsed_tree[0].label())
+             collapsed_tree = collapsed_tree[0]
+     return collapsed_tree
+
+
+ def _get_labeled_spans(tree, spans_out, start):
+     if isinstance(tree, str):
+         return start + 1
+
+     assert len(tree) > 1 or isinstance(
+         tree[0], str
+     ), "Must call collapse_unary_strip_pos first"
+     end = start
+     for child in tree:
+         end = _get_labeled_spans(child, spans_out, end)
+     # Spans are returned as closed intervals on both ends
+     spans_out.append((start, end - 1, tree.label()))
+     return end
+
+
+ def get_labeled_spans(tree):
+     """Converts a tree into a list of labeled spans.
+
+     Args:
+         tree: an nltk.tree.Tree object
+
+     Returns:
+         A list of (span_start, span_end, span_label) tuples. The start and end
+         indices indicate the first and last words of the span (a closed
+         interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will
+         result in a single span labeled "S::VP".
+     """
+     tree = collapse_unary_strip_pos(tree)
+     spans_out = []
+     _get_labeled_spans(tree, spans_out, start=0)
+     return spans_out
+
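+ # Example (a sketch): for nltk.tree.Tree.fromstring(
+ #     "(S (NP (PRP I)) (VP (VBD saw) (NP (PRP it))))"),
+ # get_labeled_spans strips the POS tags and returns the closed-interval spans
+ # [(0, 0, 'NP'), (2, 2, 'NP'), (1, 2, 'VP'), (0, 2, 'S')].
+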
+ def padded_chart_from_spans(label_vocab, spans):
+     num_words = 64
+     chart = np.full((num_words, num_words), -100, dtype=int)
+     # All entries start as -100 (ignored); labeled spans get their vocab id.
+     for start, end, label in spans:
+         if label in label_vocab:
+             chart[start, end] = label_vocab[label]
+     return chart
+
+
+ def chart_from_tree(label_vocab, tree, verbose=False):
+     spans = get_labeled_spans(tree)
+     num_words = len(tree.leaves())
+     chart = np.full((num_words, num_words), -100, dtype=int)
+     chart = np.tril(chart, -1)
+     # Now all invalid (strictly lower-triangle) entries are -100 and valid
+     # entries are 0.
+     for start, end, label in spans:
+         # Previously unseen unary chains can occur in the dev/test sets.
+         # For now, we ignore them and don't mark the corresponding chart
+         # entry as a constituent.
+         if label in label_vocab:
+             chart[start, end] = label_vocab[label]
+     if not verbose:
+         return chart
+     else:
+         return chart, spans
+
+
+ def pad_charts(charts, padding_value=-100):
+     """
+     Pad parse charts to a fixed 64x64 size.
+
+     The input text format contains START and END tokens, but the parse charts
+     do not, so each chart is shifted by one position to make room for them;
+     their span labels are left at 0.
+
+     :param charts: list of per-sentence chart arrays
+     :param padding_value: label used for ignored (invalid) entries
+     :return: a (len(charts), 64, 64) array
+     """
+     max_len = 64
+     padded_charts = torch.full(
+         (len(charts), max_len, max_len),
+         padding_value,
+     )
+     # np.tril converts the tensor to a NumPy array with the upper triangle
+     # (including the diagonal) zeroed, i.e. valid entries default to label 0.
+     padded_charts = np.tril(padded_charts, -1)
+     for i, chart in enumerate(charts):
+         chart_size = len(chart)
+         padded_charts[i, 1:chart_size + 1, 1:chart_size + 1] = chart
+
+     return padded_charts
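+
+ # Usage sketch (assumes an nltk Tree `t` with at most 63 leaves and a dict
+ # `label_vocab` mapping span labels to ids):
+ #   chart = chart_from_tree(label_vocab, t)
+ #   batch = pad_charts([chart])   # -> (1, 64, 64) array, ignored entries -100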
train.py ADDED
@@ -0,0 +1,177 @@
+ import os
+ import argparse
+ from transformers import set_seed
+ from src.scripts.mytokenizers import Tokenizer
+ from src.improved_diffusion import gaussian_diffusion as gd
+ from src.improved_diffusion.respace import SpacedDiffusion
+ from src.improved_diffusion import dist_util
+ from src.improved_diffusion.transformer_model import TransformerNetModel
+ from src.improved_diffusion.resample import create_named_schedule_sampler
+ from src.improved_diffusion.script_util import model_and_diffusion_defaults
+ from src.improved_diffusion.script_util import add_dict_to_argparser
+ from src.improved_diffusion.train_util import TrainLoop
+ import torch.distributed as dist
+ import wandb
+ from src.scripts.mydatasets import get_dataloader, Lang2molDataset_train
+ import torch.multiprocessing as mp
+
+
+ def main_worker(rank, world_size):
+     args = create_argparser().parse_args()
+     set_seed(42)
+
+     wandb.login(key=args.wandb_token)
+     wandb.init(
+         project="ACL_Lang2Mol",
+         config=args.__dict__,
+     )
+
+     dist_util.setup_dist(rank, world_size)
+     tokenizer = Tokenizer()
+     model = TransformerNetModel(
+         in_channels=args.model_in_channels,
+         model_channels=args.model_model_channels,
+         dropout=args.model_dropout,
+         vocab_size=len(tokenizer),
+         hidden_size=args.model_hidden_size,
+         num_attention_heads=args.model_num_attention_heads,
+         num_hidden_layers=args.model_num_hidden_layers,
+     )
+     if args.model_path != "":
+         model.load_state_dict(
+             dist_util.load_state_dict(args.model_path, map_location="cpu")
+         )
+
+     model.train()
+
+     print("Total params:", sum(p.numel() for p in model.parameters()))
+     print(
+         "Total trainable params:",
+         sum(p.numel() for p in model.parameters() if p.requires_grad),
+     )
+     print("Tokenizer vocab length:", len(tokenizer))
+
+     diffusion = SpacedDiffusion(
+         use_timesteps=list(range(args.diffusion_steps)),
+         betas=gd.get_named_beta_schedule("sqrt", args.diffusion_steps),
+         model_mean_type=gd.ModelMeanType.START_X,
+         model_var_type=gd.ModelVarType.FIXED_LARGE,
+         loss_type=gd.LossType.E2E_MSE,
+         rescale_timesteps=True,
+         model_arch="transformer",
+         training_mode="e2e",
+     )
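+     # Note: use_timesteps=range(diffusion_steps) keeps every step, i.e. no
+     # respacing. A strided subset such as list(range(0, args.diffusion_steps, 10))
+     # would run on a shorter schedule (a sketch; behavior depends on
+     # SpacedDiffusion's respacing logic).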
+
+     schedule_sampler = create_named_schedule_sampler("uniform", diffusion)
+
+     print("Loading data...")
+     train_dataset = Lang2molDataset_train(
+         dir=args.dataset_path,
+         tokenizer=tokenizer,
+         split="train",
+         corrupt_prob=0.0,
+         token_max_length=512,
+         dataset_name=args.dataset_name,
+     )
+     dataloader = get_dataloader(train_dataset, args.batch_size, rank, world_size)
+     print("Finished loading data")
+
+     TrainLoop(
+         model=model,
+         diffusion=diffusion,
+         data=dataloader,
+         batch_size=args.batch_size,
+         microbatch=args.microbatch,
+         lr=args.lr,
+         ema_rate=args.ema_rate,
+         log_interval=args.log_interval,
+         save_interval=args.save_interval,
+         resume_checkpoint=args.resume_checkpoint,
+         use_fp16=args.use_fp16,
+         fp16_scale_growth=args.fp16_scale_growth,
+         schedule_sampler=schedule_sampler,
+         weight_decay=args.weight_decay,
+         lr_anneal_steps=args.lr_anneal_steps,
+         checkpoint_path=args.checkpoint_path,
+         gradient_clipping=args.gradient_clipping,
+         eval_data=None,
+         eval_interval=args.eval_interval,
+     ).run_loop()
+     dist.destroy_process_group()
+
+
+ def create_argparser():
+     defaults = dict()
+     text_defaults = dict(
+         wandb_token="",
+         batch_size=16,
+         cache_mode="no",
+         checkpoint_path="checkpoints",
+         class_cond=False,
+         config="ll",
+         config_name="QizhiPei/biot5-base-text2mol",
+         dataset_path="dataset",
+         diffusion_steps=2000,
+         dropout=0.01,
+         e2e_train="",
+         ema_rate="0.9999",
+         emb_scale_factor=1.0,
+         eval_interval=2000,
+         experiment="random",
+         experiment_mode="lm",
+         fp16_scale_growth=0.001,
+         gradient_clipping=2.4,
+         image_size=8,
+         in_channel=16,
+         learn_sigma=False,
+         log_interval=1000,
+         logits_mode=1,
+         lr=0.00005,
+         lr_anneal_steps=500000,
+         microbatch=-1,
+         modality="e2e-tgt",
+         model_arch="transformer",
+         noise_level=0.0,
+         noise_schedule="sqrt",
+         num_channels=128,
+         num_heads=4,
+         num_heads_upsample=-1,
+         num_res_blocks=2,
+         out_channel=16,
+         padding_mode="pad",
+         predict_xstart=True,
+         preprocessing_num_workers=1,
+         rescale_learned_sigmas=True,
+         rescale_timesteps=True,
+         resume_checkpoint="",
+         save_interval=50000,
+         schedule_sampler="uniform",
+         seed=42,
+         timestep_respacing="",
+         training_mode="e2e",
+         use_bert_tokenizer="no",
+         use_checkpoint=False,
+         use_fp16=False,
+         use_kl=False,
+         use_scale_shift_norm=True,
+         weight_decay=0.0,
+         model_in_channels=32,
+         model_model_channels=128,
+         model_dropout=0.01,
+         model_hidden_size=1024,
+         model_num_attention_heads=16,
+         model_num_hidden_layers=12,
+         dataset_name="",
+         model_path="",
+     )
+     defaults.update(model_and_diffusion_defaults())
+     defaults.update(text_defaults)
+     parser = argparse.ArgumentParser()
+     add_dict_to_argparser(parser, defaults)
+     return parser
+
+
+ if __name__ == "__main__":
+     world_size = 1
+     mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)
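+     # To use several GPUs (a sketch), set world_size to the GPU count, e.g.
+     # torch.cuda.device_count(); mp.spawn then launches one main_worker per
+     # rank, and dist_util.setup_dist(rank, world_size) wires the ranks together.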