diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..7c686966cc50eea41c8b745f35baaccba59ad3d7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.csv filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.streamlit/config.toml b/.streamlit/config.toml new file mode 100644 index 0000000000000000000000000000000000000000..508672a57fb1caf6378b84cf71d6d68e6a6d2df6 --- /dev/null +++ b/.streamlit/config.toml @@ -0,0 +1,19 @@ +[theme] +base="light" + +# Primary accent for interactive elements +primaryColor = '#0078aa' + +# Background color for the main content area +# backgroundColor = '#273346' + +# Background color for sidebar and most interactive widgets +# secondaryBackgroundColor = '#7d828c' + +# Color used for almost all text +# textColor = '#4bc9ff' + +# Font family for all text in the app, except code blocks +# Accepted values (serif | sans serif | monospace) +# Default: "sans serif" +# font = "sans serif" \ No newline at end of file diff --git a/README.md b/README.md index 8385832fab599d9c95c4f72509c4e9ab212e63e5..29fb8da7b13aa98165cfc881994252300dd0f575 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,113 @@ --- -title: Mhnfs -emoji: πŸš€ -colorFrom: yellow -colorTo: purple +title: MHNfs +emoji: πŸ”¬ +short_description: Activity prediction for low-data scenarios +colorFrom: gray +colorTo: gray sdk: streamlit -sdk_version: 1.32.2 +sdk_version: 1.29.0 app_file: app.py -pinned: false -license: gpl-3.0 +pinned: true --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Activity Predictions with MHNfs for low-data scenarios + +## βš™οΈ Under the hood +
+ The predictive model (MHNfs) used in this application was specifically designed and + trained for low-data scenarios. The model predicts whether a molecule is active or + inactive. The predicted activity value is a continuous value between 0 and 1, and, + similar to a probability, the higher/lower the value, the more confident the model + is that the molecule is active/inactive.
+
+ The model was trained on the FS-Mol dataset which + includes 5120 tasks (roughly 5000 tasks were used for training, rest for evaluation). + The training tasks are listed here: + https://github.com/microsoft/FS-Mol/tree/main/datasets/targets. +
+ +## 🎯 About few-shot learning and the model MHNfs +
+ Few-shot learning is a machine learning sub-field which aims to provide + predictive models for scenarios in which only little data is known/available.
+
+ MHNfs is a few-shot learning model which is specifically designed for drug + discovery applications. It is built to use the input prompts in a way such that + the provided available knowledge, i.e. the known active and inactive molecules, + functions as context to predict the activity of the new requested molecules. + Precisely, the provided active and inactive molecules are associated with a + large set of general molecules - called context molecules - to enrich the + provided information and to remove spurious correlations arising from the + decoration of molecules. This is analogous to a Large Language Model which would + not only use the provided information in the current prompt as context but would + also have access to way more information, e.g., a prompting history. +
+ +## 💻 Run the prediction pipeline locally for larger screening chunks + +### Get started: +```bash +# Copied from hugging face +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install + +# Clone repo +git clone https://huggingface.co/spaces/tschouis/mhnfs + +# Alternatively, if you want to clone without large files +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/spaces/tschouis/mhnfs +``` + +### Install requirements +```bash +pip install -r requirements.txt +``` +Notably, this command was tested inside a conda environment with python 3.7. + +### Run the prediction pipeline: +For your screening, load the model, i.e. the **Activity Predictor**, into your python file or notebook and simply run it: +```python +from src.prediction_pipeline import ActivityPredictor + +# Load the model +predictor = ActivityPredictor() + +# Define inputs +query_smiles = ["C1CCCCC1", "C1CCCCC1", "C1CCCCC1", "C1CCCCC1"] # Replace with your data +support_actives_smiles = ["C1CCCCC1", "C1CCCCC1"] # Replace with your data +support_inactives_smiles = ["C1CCCCC1", "C1CCCCC1"] # Replace with your data + +# Make predictions +predictions = predictor.predict(query_smiles, support_actives_smiles, support_inactives_smiles) +``` + +* Provide molecules in SMILES notation. +* Make sure that the inputs to the Activity Predictor are either comma separated lists, or flattened numpy arrays, or pandas DataFrames. In the latter case, there should be a "smiles" column (both upper and lower case "SMILES" are accepted). All other columns are ignored. 
+ + + +### Run the app locally with streamlit: +```bash +# Navigate into root directory of this project +cd .../whatever_your_dir_name_is/ # Replace with your path + +# Run streamlit app +python -m streamlit run app.py +``` + + +## 🤗 Hugging Face app +Explore our Hugging Face app here: + +## 📚 Cite us + +``` +@inproceedings{ + schimunek2023contextenriched, + title={Context-enriched molecule representations improve few-shot drug discovery}, + author={Johannes Schimunek and Philipp Seidl and Lukas Friedrich and Daniel Kuhn and Friedrich Rippmann and Sepp Hochreiter and Günter Klambauer}, + booktitle={The Eleventh International Conference on Learning Representations}, + year={2023}, + url={https://openreview.net/forum?id=XrMWUuEevr} +} +``` + + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a33518d3d236f90b5e8aa6d906ff4cdc8cc557 --- /dev/null +++ b/app.py @@ -0,0 +1,65 @@ +""" +This script runs the streamlit app for MHNfs + +MHNfs: Few-shot method for drug discovery activity predictions + (https://openreview.net/pdf?id=XrMWUuEevr) +""" + +# -------------------------------------------------------------------------------------- +# Imports +import streamlit as st + +from src.app.layout import LayoutMaker +from src.app.prediction_utils import (create_prediction_df, + create_molecule_grid_plot) +from src.prediction_pipeline import ActivityPredictor + +# -------------------------------------------------------------------------------------- +# Functions +class App(): + def __init__(self): + # Set page configration to wide + st.set_page_config(layout="wide", page_title="MHNfs", page_icon="πŸ”¬") + + # Layout maker + self.layoutMaker = LayoutMaker() + + # Load mhnfs model + self.predictor = ActivityPredictor() + + def define_layout(self): + + # Define Sidebar width + css = ''' + + ''' + st.markdown(css, unsafe_allow_html=True) + + # Sidebar + self.inputs, self.buttons = self.layoutMaker.make_sidebar() + + # Main page + # - 
header + self.layoutMaker.make_header() + + # - main body + self.layoutMaker.make_main_content_area(self.predictor, + self.inputs, + self.buttons, + create_prediction_df, + create_molecule_grid_plot) + +def run_app(): + app = App() + app.define_layout() + + +# -------------------------------------------------------------------------------------- +# Run script +if __name__ == "__main__": + run_app() diff --git a/assets/data_preprocessing_objects/ecdfs.pkl b/assets/data_preprocessing_objects/ecdfs.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7fbab090f3b42d53c49654def367089b26477d68 --- /dev/null +++ b/assets/data_preprocessing_objects/ecdfs.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eeec12688fd9e0bb0bbd68d5203e2fb46c45d30a07417f0883adbfc133d48e9f +size 520417347 diff --git a/assets/data_preprocessing_objects/scaler_fitted.pkl b/assets/data_preprocessing_objects/scaler_fitted.pkl new file mode 100755 index 0000000000000000000000000000000000000000..0661f6a057e0debdcbc0f9edcd9d1f802ac46edd --- /dev/null +++ b/assets/data_preprocessing_objects/scaler_fitted.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4538c1c1d9b5b50d29a14c14134f66a563c3a0f4022ce77b8eb2959c3eff51ea +size 54501 diff --git a/assets/example_csv/.~lock.known_inactive_molecules.csv# b/assets/example_csv/.~lock.known_inactive_molecules.csv# new file mode 100644 index 0000000000000000000000000000000000000000..6aac86237cd3e80015c7e16f5769a5def850e413 --- /dev/null +++ b/assets/example_csv/.~lock.known_inactive_molecules.csv# @@ -0,0 +1 @@ +,johannes,Latitude-5501,02.01.2024 15:57,file:///home/johannes/.config/libreoffice/4; \ No newline at end of file diff --git a/assets/example_csv/known_active_molecules.csv b/assets/example_csv/known_active_molecules.csv new file mode 100644 index 0000000000000000000000000000000000000000..907303100b4fa831a5adb782eb0ebbd9fe64826c --- /dev/null +++ 
b/assets/example_csv/known_active_molecules.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc98c05246b42d84c6833d191efa32c7c6473d76c5f2719c8ff3310cfe22df04 +size 353 diff --git a/assets/example_csv/known_inactive_molecules.csv b/assets/example_csv/known_inactive_molecules.csv new file mode 100644 index 0000000000000000000000000000000000000000..2861e2b837c19d703a459c11898a8bcf6a069fe1 --- /dev/null +++ b/assets/example_csv/known_inactive_molecules.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6e183c33b7445ae0c00bea4a7cdae52bfce14da2829f6827e20dda162df23af +size 363 diff --git a/assets/example_csv/molecules_for_prediction.csv b/assets/example_csv/molecules_for_prediction.csv new file mode 100644 index 0000000000000000000000000000000000000000..46f476e8b149a54d7ba5048930d3ded4faf38b88 --- /dev/null +++ b/assets/example_csv/molecules_for_prediction.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:497adfdbd026c7ab7d1564b685a246fcb7eb6eabb2442918862b31ccd0b32369 +size 460 diff --git a/assets/example_csv/predictions/nottrustworthy_example.csv b/assets/example_csv/predictions/nottrustworthy_example.csv new file mode 100644 index 0000000000000000000000000000000000000000..18bbee0dbc0ff12ccb73b4e749bf62c22f2c21a9 --- /dev/null +++ b/assets/example_csv/predictions/nottrustworthy_example.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3f8b5e017175b8d62982b1fc4138a4348f51b6a0469c32df991f5d2576a679d +size 588 diff --git a/assets/example_csv/predictions/nottrustworthy_example.png b/assets/example_csv/predictions/nottrustworthy_example.png new file mode 100644 index 0000000000000000000000000000000000000000..0f9ab8ff9fa7e16f3cdbfcca74fa0049e52a1194 --- /dev/null +++ b/assets/example_csv/predictions/nottrustworthy_example.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae7aff2e2cd2e68bdcb4a5563be38c13d7780453443657b36f01333ab57a949c +size 
25505 diff --git a/assets/example_csv/predictions/trustworthy_example.csv b/assets/example_csv/predictions/trustworthy_example.csv new file mode 100644 index 0000000000000000000000000000000000000000..1ebd92bc23042b661efe278566a608ab914913a3 --- /dev/null +++ b/assets/example_csv/predictions/trustworthy_example.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3517bcef4a9998975b031d1b4f2b4aa29679669079100230f84e27bc06f80c02 +size 889 diff --git a/assets/example_csv/predictions/trustworthy_example.png b/assets/example_csv/predictions/trustworthy_example.png new file mode 100644 index 0000000000000000000000000000000000000000..d2709d25f963f96ef45b994ff2b7ab35a5b5c887 --- /dev/null +++ b/assets/example_csv/predictions/trustworthy_example.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df2a73cdf527546e8b078cb45618b4554a77f11fdd48367ef25939e0a6a2b518 +size 28331 diff --git a/assets/header.png b/assets/header.png new file mode 100644 index 0000000000000000000000000000000000000000..c5c71fbddd84d37cae9f0f6d62becfd7a47bea9e --- /dev/null +++ b/assets/header.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d355c5fc158281371a09759584110e611c810d2442e8aad30551998aa728f0a +size 122706 diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..a306190a2f0ac4f404745bf23609cade08260287 --- /dev/null +++ b/assets/logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:505cc795dcaac622e2af6bf2ed118d7ab28d3eab27fd421755844c042ed7646a +size 40875 diff --git a/assets/mhnfs_data/cfg.yaml b/assets/mhnfs_data/cfg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..20686869fab8fcae751e6349a348a7df5066e284 --- /dev/null +++ b/assets/mhnfs_data/cfg.yaml @@ -0,0 +1,42 @@ +model: + encoder: + activation: selu + input_dim: 2248 + number_hidden_layers: 0 + number_hidden_neurons: 1024 + regularization: + 
input_dropout: 0.1 + dropout: 0.5 + layerNormBlock: + affine: False + usage: True + transformer: + activity_embedding_dim: 64 + number_heads: 8 + dim_forward: 567 + dropout: 0.5 + num_layers: 1 + ss_dropout: 0.1 + hopfield: + dim_QK: 512 + heads: 8 + beta: 0.044194173824159216 + dropout: 0.5 + prediction_scaling: 0.044194173824159216 + associationSpace_dim: 1024 + similarityModule: + type: cosineSim + l2Norm: False + scaling: 1/N + training: + optimizer: AdamW + batch_size: 512 + lr: 0.0001 + weightDecay: 0.0 + lrScheduler: + usage: True + context: + ratio_training_molecules: 0.05 +system: + ressources: + device: cpu diff --git a/assets/mhnfs_data/full_context_set.npy b/assets/mhnfs_data/full_context_set.npy new file mode 100755 index 0000000000000000000000000000000000000000..d15fb13e1c80dd060fcbf7d7e580c609adef9140 --- /dev/null +++ b/assets/mhnfs_data/full_context_set.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ed40b8d9cc39859772af0d32ed69c7f2467b9235f83f37ff42611bc22828e52 +size 3899416896 diff --git a/assets/mhnfs_data/mhnfs_checkpoint.ckpt b/assets/mhnfs_data/mhnfs_checkpoint.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..89fb2781a24f2398d093d3fdc743995ef399235b --- /dev/null +++ b/assets/mhnfs_data/mhnfs_checkpoint.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25fcfdb7c6355b7781edaefc9ec56351f012356b17e4087f72b0a78c6d8e2300 +size 313588174 diff --git a/assets/mhnfs_overview.png b/assets/mhnfs_overview.png new file mode 100644 index 0000000000000000000000000000000000000000..d56d2c0c46fbe515685c4ca805cce81e6d215541 --- /dev/null +++ b/assets/mhnfs_overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f89731eaf842e6018b4153d60193ea57442fb5933774135a653d4b70ac48afe2 +size 466946 diff --git a/assets/test_reference_data/ecfps.npy b/assets/test_reference_data/ecfps.npy new file mode 100644 index 
0000000000000000000000000000000000000000..e329a21664a03188ea6cf18774762266c493d8a5 --- /dev/null +++ b/assets/test_reference_data/ecfps.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:056a628c308cf69e647f2c86090f8f93c2aedcd719845f57f11e653ce6d9d70b +size 24704 diff --git a/assets/test_reference_data/model_input_query.pt b/assets/test_reference_data/model_input_query.pt new file mode 100644 index 0000000000000000000000000000000000000000..513ad54892b7e5e6bf4cdec17ddd0dfbdd449352 --- /dev/null +++ b/assets/test_reference_data/model_input_query.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e889558eb3300355b5c6ea0ce1518bb949141238b8d26b257ec1bd496baeda18 +size 36715 diff --git a/assets/test_reference_data/model_input_support_actives.pt b/assets/test_reference_data/model_input_support_actives.pt new file mode 100644 index 0000000000000000000000000000000000000000..17bed4d70ffe83df73704148a9331b53388d0a11 --- /dev/null +++ b/assets/test_reference_data/model_input_support_actives.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5e55816e09597d267fb91297a56f58a4f4420ed32340650be4c1dd37efe1656 +size 72683 diff --git a/assets/test_reference_data/model_input_support_inactives.pt b/assets/test_reference_data/model_input_support_inactives.pt new file mode 100644 index 0000000000000000000000000000000000000000..b59ba16f1a2537f88a66cdf633db24cbf369da45 --- /dev/null +++ b/assets/test_reference_data/model_input_support_inactives.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e62e8b18da47d1c9475c18bc2ad50a563f10f0d0bced247d848e453321a13ced +size 72683 diff --git a/assets/test_reference_data/model_predictions.pt b/assets/test_reference_data/model_predictions.pt new file mode 100644 index 0000000000000000000000000000000000000000..fd8e6d8a51d47afcf6f0de8858dafc30d6ce899d --- /dev/null +++ b/assets/test_reference_data/model_predictions.pt @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:a7e63ad2ad9b664e3301479427f8d5cf005c979d7cc9e4bce033f18640eb4df0 +size 747 diff --git a/assets/test_reference_data/preprocessed_features.npy b/assets/test_reference_data/preprocessed_features.npy new file mode 100644 index 0000000000000000000000000000000000000000..f1dc33ac3aa07b8821a65da45e4003495ea31144 --- /dev/null +++ b/assets/test_reference_data/preprocessed_features.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e97dc7eb85509c6b07156292b57a1bee4eaa8d60fbdb40c7e2e5738c8c6a460 +size 54080 diff --git a/assets/test_reference_data/rdkit_descr_quantils.npy b/assets/test_reference_data/rdkit_descr_quantils.npy new file mode 100644 index 0000000000000000000000000000000000000000..b50a6ef8039d1f6438cc59d82b17cd2590605de9 --- /dev/null +++ b/assets/test_reference_data/rdkit_descr_quantils.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cde4d2fd8658cdbcd55e75f14cb360cfa1b239f99d281c1f7296449636e94c6a +size 4928 diff --git a/assets/test_reference_data/rdkit_descrs.npy b/assets/test_reference_data/rdkit_descrs.npy new file mode 100644 index 0000000000000000000000000000000000000000..ece95389f506b58a7036eb0a269e1657e980cad4 --- /dev/null +++ b/assets/test_reference_data/rdkit_descrs.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1b06153004b3f2ac02f0cefd16b0f17225527bbf53f8efe6e43c035b3d21690 +size 2528 diff --git a/assets/test_reference_data/smiles.pkl b/assets/test_reference_data/smiles.pkl new file mode 100644 index 0000000000000000000000000000000000000000..77e90c8a544d78a2023574e841bf0d81b832bf05 --- /dev/null +++ b/assets/test_reference_data/smiles.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0168a7aaa6f7f3eca611a42d70782bae9eb970194449320d37b64f5a8c264f9 +size 179 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..496ae46a66a97f77254f8e181c48ba09ec723c2e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +rdkit==2022.3.3 +pytorch-lightning==1.6.1 +torch==1.13.1 +numpy==1.21.5 +pandas==1.3.5 +omegaconf==2.1.2 +mols2grid==1.1.1 +scikit-learn +statsmodels==0.13.5 +streamlit \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/__pycache__/__init__.cpython-37.pyc b/src/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16665065e80e1a7344af4b094b1df63ace5189d8 Binary files /dev/null and b/src/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/__pycache__/prediction_pipeline.cpython-37.pyc b/src/__pycache__/prediction_pipeline.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc9e266ea5f1b7b0ebca3f8e977514f3c4c89f4f Binary files /dev/null and b/src/__pycache__/prediction_pipeline.cpython-37.pyc differ diff --git a/src/app/__pycache__/constants.cpython-37.pyc b/src/app/__pycache__/constants.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66b6c4afa19390d172e3ff5ed5baafd3159db677 Binary files /dev/null and b/src/app/__pycache__/constants.cpython-37.pyc differ diff --git a/src/app/__pycache__/layout.cpython-37.pyc b/src/app/__pycache__/layout.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cddf59c07ff5cce1e03d8b4fd9c8ca1e5ea6391f Binary files /dev/null and b/src/app/__pycache__/layout.cpython-37.pyc differ diff --git a/src/app/__pycache__/prediction_utils.cpython-37.pyc b/src/app/__pycache__/prediction_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a5351001c478b42f20803d341b223ea166c78e7 Binary files /dev/null and b/src/app/__pycache__/prediction_utils.cpython-37.pyc differ diff --git 
a/src/app/constants.py b/src/app/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..bd6844863797c2e005833589612792853fec2ccc --- /dev/null +++ b/src/app/constants.py @@ -0,0 +1,269 @@ +""" +This file includes all the constant content shown in the app +""" + +# -------------------------------------------------------------------------------------- + +summary_text = (''' + This application allows you to make **activity predictions** for + **biological targets** for which you have only a **little knowledge** in + terms of known active and inactive molecules. + + **Provide** via the sidebar:\n + - some active molecules, + - some inactive molecules, and + - molecules you want to predict. + + Hit **Predict** and explore the predictions! + + For more **information** about the **model** and **how to provide the + molecules**, please visit the **Additional Information** tab. + ''') + +mhnfs_text =(''' +
+ MHNfs is a few-shot drug discovery model which consists of a context + module , a cross-attention module , and a similarity module + as described here: https://openreview.net/pdf?id=XrMWUuEevr. +
+
+ +
+ Abstract. A central task in computational drug discovery is to construct + models from known active molecules to find further promising molecules for + subsequent screening. However, typically only very few active molecules are + known. Therefore, few-shot learning methods have the potential to improve the + effectiveness of this critical phase of the drug discovery process. We introduce + a new method for few-shot drug discovery. Its main idea is to enrich a molecule + representation by knowledge about known context or reference molecules. Our + novel concept for molecule representation enrichment is to associate molecules + from both the support set and the query set with a large set of reference + (context) molecules through a modern Hopfield network. Intuitively, this + enrichment step is analogous to a human expert who would associate a given + molecule with familiar molecules whose properties are known. The enrichment step + reinforces and amplifies the covariance structure of the data, while + simultaneously removing spurious correlations arising from the decoration of + molecules. Our approach is compared with other few-shot methods for drug + discovery on the FS-Mol benchmark dataset. On FS-Mol, our approach outperforms + all compared methods and therefore sets a new state-of-the art for few-shot + learning in drug discovery. An ablation study shows that the enrichment step of + our method is the key to improve the predictive quality. In a domain shift + experiment, we further demonstrate the robustness of our method. Code is + available at https://github.com/ml-jku/MHNfs. +
+
+
+ ''') + +citation_text = ''' + ### + @inproceedings{ + schimunek2023contextenriched, + title={Context-enriched molecule representations improve few-shot drug discovery}, + author={Johannes Schimunek and Philipp Seidl and Lukas Friedrich and Daniel Kuhn and Friedrich Rippmann and Sepp Hochreiter and GΓΌnter + Klambauer}, + booktitle={The Eleventh International Conference on Learning Representations}, + year={2023}, + url={https://openreview.net/forum?id=XrMWUuEevr} + } + ''' + +few_shot_learning_text = ( + ''' +
+ Few-shot learning is a machine learning sub-field which aims to provide + predictive models for scenarios in which only little data is known/available.
+
+ + MHNfs is a few-shot learning model which is specifically designed for drug + discovery applications. It is built to use the input prompts in a way such that + the provided available knowledge, i.e. the known active and inactive molecules, + functions as context to predict the activity of the new requested molecules. + Precisely, the provided active and inactive molecules are associated with a + large set of general molecules - called context molecules - to enrich the + provided information and to remove spurious correlations arising from the + decoration of molecules. This is analogous to a Large Language Model which would + not only use the provided information in the current prompt as context but would + also have access to way more information, e.g., a prompting history. +
+ ''') + +under_the_hood_text = (''' +
+ The predictive model (MHNfs) used in this application was specifically designed and + trained for low-data scenarios. The model predicts whether a molecule is active or + inactive. The predicted activity value is a continuous value between 0 and 1, and, + similar to a probability, the higher/lower the value, the more confident the model + is that the molecule is active/inactive. + + The model was trained on the FS-Mol dataset which + includes 5120 tasks (roughly 5000 tasks were used for training, rest for evaluation). + The training tasks are listed here: + https://github.com/microsoft/FS-Mol/tree/main/datasets/targets. +
+ ''') + +usage_text = (''' +
+ To use this application, you need to provide 3 different sets of molecules: +
    +
  1. active molecules: set of known active molecules,
  2. +
  3. inactive molecules: set of known inactive molecules, and
  4. +
  5. molecules to predict: set of molecules you want to predict.
  6. +
+ These three sets can be provided via the sidebar. The sidebar also includes two + buttons predict and reset to run the prediction pipeline and to + reset it. +
+ ''') + +data_text = (''' +
+
+ + + + + + + ''') + +trust_text = (''' +
+ Just like all other machine learning models, the performance of MHNfs varies + and, generally, the model works well if the task is somehow close to tasks which + were used to train the model. The model performance for very different tasks is + unclear and might be poor.
+
+ + MHNfs was trained on the FS-Mol dataset which includes 5120 tasks (roughly + 5000 tasks were used for training, rest for evaluation). The training tasks are + listed here: https://github.com/microsoft/FS-Mol/tree/main/datasets/targets. +
+ ''') + +example_trustworthy_text = (''' +
+ Since the predictive model has seen a lot of kinase related tasks during training, + the model is expected to generally perform well on kinase targets. For this example, + we use data for the target + CHEMBL5914. Notably, this specific kinase has not been seen + during training. Precisely, we use the available inhibition data while molecules + with an inhibition value greater (smaller) than 50 % are considered as active + (inactive).
+ + From the known available data, we have selected 4 "known" active molecules, + 8 "known" inactive molecules, and 11 molecules to predict.
+ + Molecules to predict: +
+ FC(F)(F)c1ccc(Cl)cc1CN1CCNc2ncc(-c3ccnc(N4CCNCC4)c3)cc21,
+ CS(=O)(=O)c1ccc(-n2nc(-c3cnc4[nH]ccc4c3)c3c(N)ncnc32)cc1,
+ O=C(Nc1ccccc1Cl)c1cnc2ccc(C3CCNCC3)cn12.O=C(O)C(=O)O,
+ CC(C)n1cnc2c(Nc3cccc(Cl)c3)nc(N[C@@H]3CCCC[C@@H]3N)nc21,
+ Nc1ncc(-c2ccc(NS(=O)(=O)C3CC3)cc2F)cc1-c1ccc2c(c1)CCNC2=O,
+ CCN1CCN(Cc2ccc(NC(=O)c3ccc(C)c(C#Cc4cccnc4)c3)cc2C(F)(F)F)CC1,
+ CN1CCN(c2ccc(-c3cnc4c(c3)N(Cc3cc(Cl)ccc3C(F)(F)F)CCN4)cn2)CC1,
+ CC(C)n1nc(-c2cnc(N)c(OC(F)(F)F)c2)cc1[C@H]1[C@@H]2CN(C3COC3)C[C@@H]21,
+ Nc1ncc(-c2cc([C@H]3[C@@H]4CN(C5COC5)C[C@@H]43)n(CC3CC3)n2)cc1C(F)(F)F,
+ Cc1ccc(NC(=O)C2(C(=O)Nc3ccc(Nc4ncc(F)c(-c5cc(F)c6nc(C)n(C(C)C)c6c5)n4)cc3)CC2)cc1,
+ C[C@@H](Oc1cc(-c2cnn(C3CCNCC3)c2)cnc1N)c1c(Cl)ccc(F)c1Cl +

+ + Known active molecules: +
+ CC(=O)N1CCN(c2cc(-c3cnc4c(c3)N(Cc3cc(Cl)ccc3C(F)(F)F)CCN4)ccn2)CC1,
+ CS(=O)(=O)c1cccc(Nc2nccc(-c3sc(N4CCOCC4)nc3-c3cccc(NS(=O)(=O)c4c(F)cccc4F)c3)n2)c1,
+ COc1cnccc1Nc1nc(-c2nn(Cc3c(F)cc(OCCO)cc3F)c3ccccc23)ncc1OC,
+ CN(C)[C@@H]1CC[C@@]2(C)[C@@H](CC[C@@H]3[C@@H]2CC[C@]2(C)C(c4cccc5cnccc45)=CC[C@@H]32)C1
+

+ + Known inactive molecules: +
+ c1cc(-c2c[nH]c3cnccc23)ccn1,
+ COc1ccc2c3ccnc(C(F)(F)F)c3n(CCCCN)c2c1,
+ CNS(=O)(=O)c1ccc(N(C)C)c(Nc2ncnc3cc(OC)c(OC)cc23)c1,
+ CN(C1CC1)S(=O)(=O)c1ccc(-c2cnc(N)c(-c3ccc4c(c3)CCNC4=O)c2)c(F)c1,
+ CCN1CCN(Cc2ccc(NC(=O)c3ccc(C)c(C#Cc4cnc5[nH]ccc5c4)c3)cc2C(F)(F)F)CC1,
+ CC(C)n1cc(-c2cc(-c3ccc(CN4CCOCC4)cc3)cnc2N)nn1,
+ CC(C)(O)[C@H](F)CN1Cc2cc(NC(=O)c3cnn4cccnc34)c(N3CCOCC3)cc2C1=O,
+ [2H]C([2H])([2H])C1(C([2H])([2H])[2H])Cn2nc(-c3ccc(F)cn3)c(-c3ccnc4[nH]ncc34)c2CO1
+

+ + Predictions:
+ +
+ ''') + +example_nottrustworthy_text = (''' +
+ For this example, we use data for the auxiliary transport protein target + CHEMBL5738. Precisely, we use the available Ki data + while molecules with a pCHEMBL value greater (smaller) than 5 are considered + as active (inactive).
+ + From the known available data, we have selected 4 "known" active molecules, + 3 "known" inactive molecules, and 10 molecules to predict.
+ + Molecules to predict: +
+ CC(C(=O)O)c1ccc(-c2ccccc2)c(F)c1,
+ O=S(=O)(O)Oc1cccc2cccc(Nc3ccccc3)c12,
+ CCCCCCCC/C=C\CCCCCCCC(=O)O,
+ C[C@]12C=CC(=O)C=C1CC[C@@H]1[C@@H]2[C@@H](O)C[C@@]2(C)[C@H]1CC[C@]2(O)C(=O)CO,
+ CCOC(=O)C(C)(C)Oc1ccc(Cl)cc1,
+ Cc1ccc(Cl)c(Nc2ccccc2C(=O)O)c1Cl,
+ O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,
+ CC(C)(Oc1ccc(CCNC(=O)c2ccc(Cl)cc2)cc1)C(=O)O,
+ O=C(c1ccccc1)c1ccc2n1CCC2C(=O)O,
+ CC(C)OC(=O)C(C)(C)Oc1ccc(C(=O)c2ccc(Cl)cc2)cc1
+

+ + Known active molecules: +
+ CC(C(=O)O)c1ccc(N2Cc3ccccc3C2=O)cc1,
+ CN1C(=O)CN=C(c2ccccc2)c2cc(Cl)ccc21,
+ CC(C)(Oc1ccc(C(=O)c2ccc(Cl)cc2)cc1)C(=O)O,
+ CC(=O)[C@H]1CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(C)[C@H]3CC[C@]12C + +

+ + Known inactive molecules: +
+ CC(C)Cc1ccc(C(C)C(=O)O)cc1,
+ O=C1Nc2ccc(Cl)cc2C(c2ccccc2Cl)=NC1O,
+ C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C@@]3(F)[C@@H](O)C[C@]2(C)[C@@]1(O)C(=O)CO +

+ + Predictions:
+ +
# Maximum number of molecules accepted per input (query, actives, inactives).
MAX_INPUT_LENGTH = 20


class LayoutMaker():
    """
    Build the Streamlit layout of the app: header, sidebar and tabbed main area.

    The sidebar widgets fill ``self.inputs`` and ``self.buttons``; the main content
    area reads both to decide whether predictions are computed and displayed.

    Icons are written as real emoji. In the reviewed source they appeared as
    mis-encoded byte sequences, which are not valid single-emoji icons.
    """

    def __init__(self):
        """Create the input/button storage and load the static tab content."""
        # Raw widget values keyed by "query", "support_active", "support_inactive"
        self.inputs = dict()
        # The same inputs converted to plain lists (filled when predicting)
        self.inputs_lists = dict()

        # Prediction storage
        self.predictions = None

        # Storage for the button states
        self.buttons = dict()

        # Static text content shown in the tabs
        self.summary_text = summary_text
        self.mhnfs_text = mhnfs_text
        self.citation_text = citation_text
        self.few_shot_learning_text = few_shot_learning_text
        self.under_the_hood_text = under_the_hood_text
        self.usage_text = usage_text
        self.data_text = data_text
        self.trust_text = trust_text
        self.example_trustworthy_text = example_trustworthy_text
        self.example_nottrustworthy_text = example_nottrustworthy_text

        # Example prediction tables for the "Examples" tab
        self.df_trustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                          "trustworthy_example.csv")
        self.df_nottrustworthy = pd.read_csv("./assets/example_csv/predictions/"
                                             "nottrustworthy_example.csv")

        self.max_input_length = MAX_INPUT_LENGTH

    def make_sidebar(self):
        """
        Render the sidebar: logo, query box, both support set boxes and the
        predict/reset buttons.

        Returns the stored inputs (query and support sets) and the button states.
        """
        with st.sidebar:
            # Logo
            logo = Image.open("./assets/logo.png")
            st.image(logo)
            st.divider()

            # Query box
            self._make_query_box()
            st.divider()

            # Support set actives box
            self._make_active_support_set_box()
            st.divider()

            # Support set inactives box
            self._make_inactive_support_set_box()
            st.divider()

            # Predict buttons
            self.buttons["predict"] = st.button("Predict...")
            self.buttons["reset"] = st.button("Reset")

        return self.inputs, self.buttons

    def make_header(self):
        """Render the header, which consists only of the header png image."""
        with st.container():
            header = Image.open("./assets/header.png")
            st.image(header)

    def make_main_content_area(self,
                               predictor,
                               inputs,
                               buttons,
                               create_prediction_df: callable,
                               create_molecule_grid_plot: callable):
        """
        Render the four tabs of the main content area.

        ``predictor``, ``inputs``, ``buttons`` and the two callables are handed
        through to the "Predictions" tab; the other tabs show static content.
        """
        tab1, tab2, tab3, tab4 = st.tabs(["Predictions",
                                          "Paper / Cite",
                                          "Additional Information",
                                          "Examples"])

        # Results tab
        with tab1:
            self._fill_tab_with_results_content(predictor,
                                                inputs,
                                                buttons,
                                                create_prediction_df,
                                                create_molecule_grid_plot)

        # Paper tab
        with tab2:
            self._fill_paper_and_citation_tab()

        # More explanations tab
        with tab3:
            self._fill_more_explanations_tab()

        # Examples tab
        with tab4:
            self._fill_examples_tab()

    def _make_smiles_input(self, radio_key, text_label, text_key, default_smiles,
                           upload_label, upload_key):
        """
        Render the "Text box / CSV upload" selector together with the chosen widget.

        Returns the text area content (str), the uploaded CSV as a DataFrame, or
        None when CSV upload is selected but no file has been provided yet.
        """
        with st.container():
            choice = st.radio("Input your data in SMILES notation via:",
                              ["Text box", "CSV upload"],
                              key=radio_key)

            if choice == "Text box":
                return st.text_area(label=text_label,
                                    label_visibility="hidden",
                                    key=text_key,
                                    value=default_smiles)

            uploaded_file = st.file_uploader(key=upload_key,
                                             label=upload_label,
                                             label_visibility="hidden")
            if uploaded_file is None:
                return None
            return pd.read_csv(uploaded_file)

    def _make_query_box(self):
        """Render the query box and store its input under ``inputs["query"]``."""
        st.info(":blue[Molecules to predict:]", icon="❓")

        self.inputs["query"] = self._make_smiles_input(
            radio_key=None,
            text_label="SMILES input for query molecules",
            text_key="query_textbox",
            default_smiles="CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
                           "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O",
            upload_label="CSV upload for query mols",
            upload_key="query_csv",
        )

    def _make_active_support_set_box(self):
        """
        Render the active support set box and store its input under
        ``inputs["support_active"]``.
        """
        st.info(":blue[Known active molecules:]", icon="✨")

        self.inputs["support_active"] = self._make_smiles_input(
            radio_key="active_input_choice",
            text_label="SMILES input for active support set molecules",
            text_key="active_textbox",
            default_smiles="Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O, "
                           "CSc1nc(C(C)C)nc(OCC(=O)O)c1C#N",
            upload_label="CSV upload for active support set molecules",
            upload_key="support_active_csv",
        )

    def _make_inactive_support_set_box(self):
        """
        Render the inactive support set box and store its input under
        ``inputs["support_inactive"]``.
        """
        st.info(":blue[Known inactive molecules:]", icon="✨")

        self.inputs["support_inactive"] = self._make_smiles_input(
            radio_key="inactive_input_choice",
            text_label="SMILES input for inactive support set molecules",
            text_key="inactive_textbox",
            default_smiles="CSc1nc(C)nc(OCC(=O)O)c1C#N, "
                           "CSc1nc(C)n(CC(=O)O)c(=O)c1C#N",
            upload_label="CSV upload for inactive support set molecules",
            upload_key="support_inactive_csv",
        )

    def _fill_tab_with_results_content(self, predictor, inputs, buttons,
                                       create_prediction_df, create_molecule_grid_plot):
        """
        Fill the "Predictions" tab: summary text plus, depending on the pressed
        button, either the prediction results or a reset of the session state.
        """
        with st.container():
            # Info
            st.info(":blue[Summary:]", icon="🚀")
            st.markdown(self.summary_text)

            # Results
            st.info(":blue[Results:]", icon="👨‍💻")

            if buttons['predict']:
                self._run_prediction(predictor,
                                     inputs,
                                     create_prediction_df,
                                     create_molecule_grid_plot)
            elif buttons['reset']:
                self._reset()

    def _run_prediction(self, predictor, inputs, create_prediction_df,
                        create_molecule_grid_plot):
        """Validate the inputs, run the prediction and render table, downloads, grid."""
        # Check 1: Are all inputs provided?
        if (inputs['query'] is None or
                inputs['support_active'] is None or
                inputs['support_inactive'] is None):
            st.error("You didn't provide all necessary inputs.\n\n"
                     "Please provide all three necessary inputs via the "
                     "sidebar and hit the predict button again.")
            return

        # Check 2: Less than max allowed molecules provided?
        largest_input = 0
        for key, raw_input in inputs.items():
            input_list = handle_inputs(raw_input)
            self.inputs_lists[key] = input_list
            largest_input = max(largest_input, len(input_list))

        if largest_input > self.max_input_length:
            st.error("You provided too many molecules. The number of "
                     "molecules for each input is restricted to "
                     f"{self.max_input_length}.\n\n"
                     "For larger screenings, we suggest to clone the repo "
                     "and to run the model locally.")
            return

        # Progress bar
        progress_bar_text = ("I'm predicting activities. This might "
                             "need some minutes. Please wait...")
        progress_bar = st.progress(50, text=progress_bar_text)

        # Results table
        df = self._predict_and_create_results_table(predictor,
                                                    inputs,
                                                    create_prediction_df)

        progress_bar.progress(100, text="Done. Here are the results:")
        st.dataframe(df, use_container_width=True)

        col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
        # Provide download button for predictions
        with col2:
            self.buttons["download_results"] = st.download_button(
                "Download predictions as CSV",
                self._convert_df_to_binary(df),
                file_name="predictions.csv",
            )

        # Provide download button for inputs. The data is handed to the button
        # in memory, so no file is opened on the server (the previous
        # open("inputs.yml", "w") only truncated/created an unused file).
        with col3:
            self.buttons["download_inputs"] = st.download_button(
                "Download inputs as YML",
                self._convert_to_yml(self.inputs_lists),
                file_name="inputs.yml",
            )
        st.divider()

        # Results grid
        st.info(":blue[Grid plot of the predicted molecules:]", icon="📊")
        mol_html_grid = create_molecule_grid_plot(df)
        components.html(mol_html_grid, height=1000, scrolling=True)

    def _fill_paper_and_citation_tab(self):
        """Fill the "Paper / Cite" tab with the paper summary, overview and BibTex."""
        st.info(":blue[**Paper: Context-enriched molecule representations improve "
                "few-shot drug discovery**]", icon="📄")
        st.markdown(self.mhnfs_text, unsafe_allow_html=True)
        st.image("./assets/mhnfs_overview.png")
        st.write("")
        st.write("")
        st.write("")
        st.info(":blue[**Cite us / BibTex**]", icon="📚")
        st.markdown(self.citation_text)

    def _fill_more_explanations_tab(self):
        """Fill the "Additional Information" tab with its five text sections."""
        sections = [
            (":blue[**Under the hood**]", "⚙️", self.under_the_hood_text),
            (":blue[**About few-shot learning and the model MHNfs**]", "🎯",
             self.few_shot_learning_text),
            (":blue[**Usage**]", "🎛️", self.usage_text),
            (":blue[**How to provide the data**]", "📤", self.data_text),
            # NOTE(review): the last byte of this icon was lost in the garbled
            # source; the magnifying glass is assumed - please confirm.
            (":blue[**When to trust the predictions**]", "🔍", self.trust_text),
        ]

        for position, (title, icon, text) in enumerate(sections):
            if position > 0:
                # Two empty lines separate consecutive sections
                st.write("")
                st.write("")
            st.info(title, icon=icon)
            st.markdown(text, unsafe_allow_html=True)

    def _fill_examples_tab(self):
        """Fill the "Examples" tab with a trustworthy and a not trustworthy example."""
        st.info(":blue[**Example for trustworthy predictions**]", icon="✅")
        st.markdown(self.example_trustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_trustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.96**)")
        prediction_plot_tw = Image.open("./assets/example_csv/predictions/"
                                        "trustworthy_example.png")
        st.image(prediction_plot_tw)
        st.write("")
        st.write("")

        st.info(":blue[**Example for not trustworthy predictions**]", icon="⛔️")
        st.markdown(self.example_nottrustworthy_text, unsafe_allow_html=True)
        st.dataframe(self.df_nottrustworthy, use_container_width=True)
        st.markdown("**Plot: Predictions for active and inactive molecules (model AUC="
                    "0.42**)")
        prediction_plot_ntw = Image.open("./assets/example_csv/predictions/"
                                         "nottrustworthy_example.png")
        st.image(prediction_plot_ntw)

    def _predict_and_create_results_table(self,
                                          predictor,
                                          inputs,
                                          create_prediction_df: callable):
        """Call ``create_prediction_df`` with the three inputs and return its result."""
        df = create_prediction_df(predictor,
                                  inputs['query'],
                                  inputs['support_active'],
                                  inputs['support_inactive'])
        return df

    def _reset(self):
        """Remove every entry from the Streamlit session state."""
        # Copy the keys first: the session state is modified while looping
        for key in list(st.session_state.keys()):
            st.session_state.pop(key)

    def _convert_df_to_binary(self, df):
        """Return the dataframe as UTF-8 encoded CSV bytes (without the index)."""
        return df.to_csv(index=False).encode('utf-8')

    def _convert_to_yml(self, inputs):
        """Return ``inputs`` serialized as a YAML string."""
        return yaml.dump(inputs)


def create_prediction_df(predictor, query_smiles, support_activces_smiles,
                         support_inactives_smiles):
    """
    Create a dataframe with the query molecules and the corresponding predictions.

    The molecules are read back from the predictor after the prediction call; the
    predicted values are stored as strings in the column "Predicted activity".
    """
    # Make predictions
    predictions = predictor.predict(query_smiles, support_activces_smiles,
                                    support_inactives_smiles)

    smiles = predictor._return_query_mols_as_list()

    # Create dataframe
    prediction_df = pd.DataFrame({"Molecule": smiles,
                                  "Predicted activity": predictions.astype('str')})

    return prediction_df


def create_molecule_grid_plot(df, smiles_col="Molecule"):
    """Return the HTML of a mols2grid grid built from column ``smiles_col`` of ``df``."""
    mol_html_grid = mols2grid.display(df, smiles_col=smiles_col)._repr_html_()
    return mol_html_grid
a/src/data_preprocessing/__pycache__/__init__.cpython-37.pyc b/src/data_preprocessing/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcac0f6f0dd8f2817ee437e9ef28311708500bb6 Binary files /dev/null and b/src/data_preprocessing/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/data_preprocessing/__pycache__/constants.cpython-37.pyc b/src/data_preprocessing/__pycache__/constants.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f8dba5d790ecb68272c5420afaded6ed2e6a958 Binary files /dev/null and b/src/data_preprocessing/__pycache__/constants.cpython-37.pyc differ diff --git a/src/data_preprocessing/__pycache__/create_descriptors.cpython-36.pyc b/src/data_preprocessing/__pycache__/create_descriptors.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e169defd0f2b9ec7587cae9bcec9aff4afce3a77 Binary files /dev/null and b/src/data_preprocessing/__pycache__/create_descriptors.cpython-36.pyc differ diff --git a/src/data_preprocessing/__pycache__/create_descriptors.cpython-37.pyc b/src/data_preprocessing/__pycache__/create_descriptors.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..295a5c4e106008fc4c0684d43cf3ecc927f0ef56 Binary files /dev/null and b/src/data_preprocessing/__pycache__/create_descriptors.cpython-37.pyc differ diff --git a/src/data_preprocessing/__pycache__/create_model_inputs.cpython-37.pyc b/src/data_preprocessing/__pycache__/create_model_inputs.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f68b11a09f0b68718b3b033b27010ba1725101d Binary files /dev/null and b/src/data_preprocessing/__pycache__/create_model_inputs.cpython-37.pyc differ diff --git a/src/data_preprocessing/__pycache__/utils.cpython-37.pyc b/src/data_preprocessing/__pycache__/utils.cpython-37.pyc new file mode 100644 index 
# Positions kept from the per-molecule RDKit descriptor vector: 0-16 and 25-207
# (200 entries in total; positions 17-24 are left out).
USED_200_DESCR = [*range(0, 17), *range(25, 208)]
def preprocess_molecules(input_molecules: "str | List[str] | pd.DataFrame"):
    """
    Preprocess molecules (assumed to be in SMILES format) and create the
    descriptors which can be fed into MHNfs.

    Steps: normalize the input to a list of SMILES, build cleaned RDKit molecules,
    compute ECFP fingerprints and RDKit descriptors, map the descriptors to
    quantiles with the stored ECDFs, concatenate both feature groups and scale them
    with the stored scaler.

    Returns the output of ``scaler.transform`` for the concatenated features.
    """
    # Local import: keeps this function self-contained without touching the
    # module-level import block.
    from pathlib import Path

    # The assets folder lives two directory levels above this file's folder.
    # pathlib is used instead of splitting on "/" so this also works with
    # Windows path separators.
    objects_dir = (Path(__file__).resolve().parents[2]
                   / "assets" / "data_preprocessing_objects")

    # Load needed objects
    # NOTE: pickle must only be used on trusted files; these ship with the repo.
    with open(objects_dir / "scaler_fitted.pkl", "rb") as fl:
        scaler = pickle.load(fl)

    with open(objects_dir / "ecdfs.pkl", "rb") as fl:
        ecdfs = pickle.load(fl)

    # Ensure that input_molecules is an Iterable with strs
    input_smiles = handle_inputs(input_molecules)

    # Create cleaned rdkit mol objects
    input_molecules = create_cleaned_mol_objects(input_smiles)

    # Create fingerprints and descriptors
    ecfps = create_ecfp_fps(input_molecules)
    rdkit_descrs = create_rdkit_descriptors(input_molecules)

    # Create quantils
    rdkit_descr_quantils = create_quantils(rdkit_descrs, ecdfs)

    # Concatenate features
    raw_features = np.concatenate((ecfps, rdkit_descr_quantils), axis=1)

    # Normalize feature vectors
    normalized_features = scaler.transform(raw_features)

    # Return feature vectors
    return normalized_features


def handle_inputs(input_molecules: "str | List[str] | pd.DataFrame"):
    """
    Convert the supported input formats into a list of SMILES.

    * list: returned unchanged.
    * DataFrame: the values of the column whose name equals "smiles" (compared
      case-insensitively) are returned. The DataFrame itself is not modified.
    * str: split at commas; surrounding whitespace and empty entries are dropped.

    Raises:
        ValueError: if a DataFrame has no "smiles" column.
        TypeError: if the input is none of the supported types.
    """
    if isinstance(input_molecules, list):
        return input_molecules

    if isinstance(input_molecules, pd.DataFrame):
        # Look the column up instead of renaming all columns in place, which
        # would silently alter the caller's DataFrame.
        smiles_columns = [column for column in input_molecules.columns
                          if str(column).lower() == "smiles"]
        if not smiles_columns:
            raise ValueError(("Input DataFrame must have a column named 'Smiles'."))
        return list(input_molecules[smiles_columns[0]].values)

    if isinstance(input_molecules, str):
        stripped = (smiles.strip() for smiles in input_molecules.split(","))
        return [smiles for smiles in stripped if smiles != ""]

    raise TypeError(("Input molecules must be a string,a list of strings or a "
                     "pandas DataFrame."))


def create_ecfp_fps(mols: List["Mol"]) -> np.ndarray:
    """
    Create ECFP (Morgan) count fingerprints for a list of molecules.

    Each fingerprint is converted into an int8 numpy array; the arrays are
    stacked into one array with one row per molecule.
    """
    ecfps = list()

    for mol in mols:
        fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
            [mol], fpType=rdFingerprintGenerator.MorganFP
        )[0]
        # ConvertToNumpyArray writes the fingerprint into the provided array
        fp = np.zeros((0,), np.int8)
        DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)

        ecfps.append(fp)

    return np.array(ecfps)


def create_rdkit_descriptors(mols: List["Mol"]) -> np.ndarray:
    """
    Create RDKit descriptors for a list of molecules.

    Every descriptor function listed in ``Descriptors._descList`` is evaluated;
    afterwards only the positions given by ``USED_200_DESCR`` are kept.
    """
    rdkit_descriptors = list()

    for mol in mols:
        descrs = np.array([descr_calc_fn(mol)
                           for _, descr_calc_fn in Descriptors._descList])
        rdkit_descriptors.append(descrs[USED_200_DESCR])

    return np.array(rdkit_descriptors)


def create_quantils(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
    """
    Map every feature column to quantiles using the matching ECDF.

    ``ecdfs[i]`` is called with the values of column ``i`` and its result is
    written to column ``i`` of the output.

    The output is always floating point: a float input keeps its dtype, any other
    dtype (e.g. integers) yields float64 so that quantiles are not truncated.
    """
    if np.issubdtype(raw_features.dtype, np.floating):
        out_dtype = raw_features.dtype
    else:
        out_dtype = float
    quantils = np.zeros(raw_features.shape, dtype=out_dtype)

    for column in range(raw_features.shape[1]):
        raw_values = raw_features[:, column].reshape(-1)
        quantils[:, column] = ecdfs[column](raw_values)

    return quantils


def create_cleaned_mol_objects(smiles: List[str]) -> List["Mol"]:
    """
    Create cleaned RDKit mol objects from a list of SMILES.

    Each SMILES is parsed, standardized (including tautomer canonicalization) and
    re-created from the SMILES of the standardized molecule.

    Raises:
        ValueError: if a SMILES cannot be parsed or the standardization fails.
            Previously such inputs surfaced as an opaque error from the SMILES
            export of ``None``.
    """
    sm = Standardizer(canon_taut=True)

    mols = list()
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            raise ValueError(f"Could not parse SMILES string: {smile!r}")

        # On failure standardize_mol returns (None, error message)
        standardized_mol, info = sm.standardize_mol(mol)
        if standardized_mol is None:
            raise ValueError(f"Could not standardize molecule {smile!r}: {info}")

        can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
        mols.append(can_mol)
    return mols
def _smiles_to_feature_matrix(smiles_input):
    """Run the preprocessing and check that a 2-dimensional array came back."""
    feature_matrix = preprocess_molecules(smiles_input)
    assert len(feature_matrix.shape) == 2
    return feature_matrix


def create_query_input(smiles_input: [str, List[str], pd.DataFrame]):
    """
    Build the model input for the query molecules.

    The preprocessed feature matrix is converted into a float tensor with an
    additional axis inserted at position 1.
    """
    feature_matrix = _smiles_to_feature_matrix(smiles_input)
    query_tensor = torch.from_numpy(feature_matrix)
    query_tensor = query_tensor.unsqueeze(1)
    return query_tensor.float()


def create_support_set_input(smiles_input: [str, List[str], pd.DataFrame]):
    """
    Build the model input for the support set molecules.

    Returns a tuple of the float tensor of the preprocessed feature matrix with an
    additional leading axis, and a tensor holding the number of rows of the
    feature matrix.
    """
    feature_matrix = _smiles_to_feature_matrix(smiles_input)
    n_molecules = feature_matrix.shape[0]

    support_tensor = torch.from_numpy(feature_matrix).unsqueeze(0).float()
    return support_tensor, torch.tensor(n_molecules)
class BaseLogger:
    """
    Simple logging base class.

    Inherit from this class and call self.get_logger() to
    get a logger bearing the class name.
    """

    DEFAULT_LOG_LEVEL = logging.WARNING

    def __init__(self):
        self._log_level = self.DEFAULT_LOG_LEVEL

    def set_log_level(self, log_level):
        """
        Set the level used by loggers returned from get_logger().

        :param log_level: name of a level constant of the logging module,
                          e.g. "DEBUG" or "NOTSET"
        :raises TypeError: if log_level is not the name of a logging level
        """
        # Look the name up with a default so that unknown names end in the
        # TypeError below instead of an AttributeError, and test for an integer
        # constant rather than truthiness so that "NOTSET" (value 0) is accepted.
        level = getattr(logging, log_level, None) if isinstance(log_level, str) else None
        if not isinstance(level, int) or isinstance(level, bool):
            raise TypeError(f"log_level {log_level} does not exist in logging")
        self._log_level = log_level

    def get_logger(self):
        """Return a logger bearing the class name."""
        logger = logging.getLogger(self.__class__.__name__)
        if not logger.hasHandlers():
            handler = logging.StreamHandler()
            formatter = logging.Formatter("[%(asctime)s:%(name)s:%(levelname)s] %(message)s")
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        logger.setLevel(self._log_level)
        return logger


class Standardizer(BaseLogger):
    """
    Simple wrapper class around rdkit Standardizer.

    The rdMolStandardize helper objects are created lazily on first access of
    the corresponding property and cached on the instance.
    """

    DEFAULT_CANON_TAUT = False
    DEFAULT_METAL_DISCONNECT = False
    MAX_TAUTOMERS = 100
    MAX_TRANSFORMS = 100
    MAX_RESTARTS = 200
    PREFER_ORGANIC = True

    def __init__(
        self,
        metal_disconnect=None,
        canon_taut=None,
    ):
        """
        Constructor.

        All parameters are optional.
        :param metal_disconnect: if True, metallorganic complexes are
                                 disconnected
        :param canon_taut: if True, molecules are converted to their
                           canonical tautomer
        """
        super().__init__()
        if metal_disconnect is None:
            metal_disconnect = self.DEFAULT_METAL_DISCONNECT
        if canon_taut is None:
            canon_taut = self.DEFAULT_CANON_TAUT
        self._canon_taut = canon_taut
        self._metal_disconnect = metal_disconnect
        self._taut_enumerator = None
        self._rdlogger = None
        self._uncharger = None
        self._lfrag_chooser = None
        self._metal_disconnector = None
        self._normalizer = None
        self._reionizer = None
        self._params = None

    @property
    def params(self):
        """Return the MolStandardize CleanupParameters."""
        if self._params is None:
            self._params = rdMolStandardize.CleanupParameters()
            self._params.maxTautomers = self.MAX_TAUTOMERS
            self._params.maxTransforms = self.MAX_TRANSFORMS
            self._params.maxRestarts = self.MAX_RESTARTS
            self._params.preferOrganic = self.PREFER_ORGANIC
            self._params.tautomerRemoveSp3Stereo = False
        return self._params

    @property
    def canon_taut(self):
        """Return whether tautomer canonicalization will be done."""
        return self._canon_taut

    @property
    def metal_disconnect(self):
        """Return whether metallorganic complexes will be disconnected."""
        return self._metal_disconnect

    @property
    def taut_enumerator(self):
        """Return the TautomerEnumerator object."""
        if self._taut_enumerator is None:
            self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
        return self._taut_enumerator

    @property
    def uncharger(self):
        """Return the Uncharger object."""
        if self._uncharger is None:
            self._uncharger = rdMolStandardize.Uncharger()
        return self._uncharger

    @property
    def lfrag_chooser(self):
        """Return the LargestFragmentChooser object."""
        if self._lfrag_chooser is None:
            self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(self.params.preferOrganic)
        return self._lfrag_chooser

    @property
    def metal_disconnector(self):
        """Return the MetalDisconnector object."""
        if self._metal_disconnector is None:
            self._metal_disconnector = rdMolStandardize.MetalDisconnector()
        return self._metal_disconnector

    @property
    def normalizer(self):
        """Return the Normalizer object."""
        if self._normalizer is None:
            self._normalizer = rdMolStandardize.Normalizer(
                self.params.normalizationsFile, self.params.maxRestarts
            )
        return self._normalizer

    @property
    def reionizer(self):
        """Return the Reionizer object."""
        if self._reionizer is None:
            self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
        return self._reionizer

    def charge_parent(self, mol_in):
        """Sequentially apply a series of MolStandardize operations:

        * MetalDisconnector
        * Normalizer
        * Reionizer
        * LargestFragmentChooser
        * Uncharger

        The net result is that a desalted, normalized, neutral
        molecule with implicit Hs is returned.
        """
        params = Chem.RemoveHsParameters()
        params.removeAndTrackIsotopes = True
        mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
        if self._metal_disconnect:
            mol_in = self.metal_disconnector.Disconnect(mol_in)
        normalized = self.normalizer.normalize(mol_in)
        Chem.SanitizeMol(normalized)
        normalized = self.reionizer.reionize(normalized)
        Chem.AssignStereochemistry(normalized)
        normalized = self.lfrag_chooser.choose(normalized)
        normalized = self.uncharger.uncharge(normalized)
        # need this to reassess aromaticity on things like
        # cyclopentadienyl, tropylium, azolium, etc.
        Chem.SanitizeMol(normalized)
        return Chem.RemoveHs(Chem.AddHs(normalized))

    def standardize_mol(self, mol_in):
        """
        Standardize a single molecule.

        :param mol_in: a Chem.Mol
        :return: * (standardized Chem.Mol, n_taut) tuple
                   if success. n_taut will be negative if
                   tautomer enumeration was aborted due
                   to reaching a limit
                 * (None, error_msg) if failure

        This calls self.charge_parent() and, if self._canon_taut
        is True, runs tautomer canonicalization.
        """
        logger = self.get_logger()
        if self._rdlogger is None:
            self._rdlogger = RDLogger.logger()
            self._rdlogger.setLevel(RDLogger.CRITICAL)
        n_tautomers = 0
        if isinstance(mol_in, Chem.Mol):
            name = None
            try:
                name = mol_in.GetProp("_Name")
            except KeyError:
                pass
            if not name:
                name = "NONAME"
        else:
            error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
            logger.critical(error)
            return None, error
        try:
            mol_out = self.charge_parent(mol_in)
        except Exception as e:
            error = f"charge_parent FAILED: {str(e).strip()}"
            logger.critical(error)
            return None, error
        if self._canon_taut:
            try:
                res = self.taut_enumerator.Enumerate(mol_out, False)
            except TypeError:
                # we are still on the pre-2021 RDKit API
                res = self.taut_enumerator.Enumerate(mol_out)
            except Exception as e:
                # something else went wrong
                error = f"canon_taut FAILED: {str(e).strip()}"
                logger.critical(error)
                return None, error
            n_tautomers = len(res)
            if hasattr(res, "status"):
                completed = res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
            else:
                # we are still on the pre-2021 RDKit API
                completed = len(res) < 1000
            if not completed:
                n_tautomers = -n_tautomers
            try:
                mol_out = self.taut_enumerator.PickCanonical(res)
            except AttributeError:
                # we are still on the pre-2021 RDKit API
                # Select by score only: comparing (score, mol) tuples would fall
                # back to comparing the molecules themselves when scores tie.
                mol_out = max(res, key=self.taut_enumerator.ScoreTautomer)
            except Exception as e:
                # something else went wrong
                error = f"canon_taut FAILED: {str(e).strip()}"
                logger.critical(error)
                return None, error
        mol_out.SetProp("_Name", name)
        return mol_out, n_tautomers
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mhnfs/__pycache__/__init__.cpython-37.pyc b/src/mhnfs/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba5ad560e223c1e1893671503c646d9c56befd9 Binary files /dev/null and b/src/mhnfs/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/mhnfs/__pycache__/initialization.cpython-37.pyc b/src/mhnfs/__pycache__/initialization.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d2ef68d4230e8bb8d83a729df808c1938a8b55b Binary files /dev/null and b/src/mhnfs/__pycache__/initialization.cpython-37.pyc differ diff --git a/src/mhnfs/__pycache__/model.cpython-37.pyc b/src/mhnfs/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3ee2a7bbffadf8060504f08dd61abe0a918a882 Binary files /dev/null and b/src/mhnfs/__pycache__/model.cpython-37.pyc differ diff --git a/src/mhnfs/__pycache__/modules.cpython-37.pyc b/src/mhnfs/__pycache__/modules.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a84eb2cf9aaeb0bf00f019b67066dbf143d1414a Binary files /dev/null and b/src/mhnfs/__pycache__/modules.cpython-37.pyc differ diff --git a/src/mhnfs/hopfield/LICENSE b/src/mhnfs/hopfield/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..47a0284f6bbbea8399f555abb5ff5932c873e95f --- /dev/null +++ b/src/mhnfs/hopfield/LICENSE @@ -0,0 +1,79 @@ +From Hopfield layers: + +Copyright (c) 2020, Institute for Machine Learning, Johannes Kepler University Linz (Bernhard Schäfl) +All rights reserved. + +All other contributions: +Copyright (c) 2020 the respective contributors +All rights reserved.
+ +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
diff --git a/src/mhnfs/hopfield/README.md b/src/mhnfs/hopfield/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc3afb01437ec35fcf8864ae87e6055b4886d473 --- /dev/null +++ b/src/mhnfs/hopfield/README.md @@ -0,0 +1,119 @@ +# Hopfield Networks is All You Need + +_Hubert Ramsauer1, +Bernhard Schäfl1, +Johannes Lehner1, +Philipp Seidl1, +Michael Widrich1, +Lukas Gruber1, +Markus Holzleitner1, +Milena Pavlović3, 4, +Geir Kjetil Sandve4, +Victor Greiff3, +David Kreil2, +Michael Kopp2, +Günter Klambauer1, +Johannes Brandstetter1, +Sepp Hochreiter1, 2_ + +1 ELLIS Unit Linz and LIT AI Lab, Institute for Machine Learning, Johannes Kepler University Linz, Austria +2 Institute of Advanced Research in Artificial Intelligence (IARAI) +3 Department of Immunology, University of Oslo, Norway +4 Department of Informatics, University of Oslo, Norway + +--- + +##### Detailed blog post on this paper as well as the necessary background on Hopfield networks at [this link](https://ml-jku.github.io/hopfield-layers/). + +--- + +The transformer and BERT models pushed the performance on NLP tasks to new levels via their attention mechanism. We show +that this attention mechanism is the update rule of a modern Hopfield network with continuous states. This new Hopfield +network can store exponentially (with the dimension) many patterns, converges with one update, and has exponentially +small retrieval errors. The number of stored patterns must be traded off against convergence speed and retrieval error. +The new Hopfield network has three types of energy minima (fixed points of the update): + +1. global fixed point averaging over all patterns, +2. metastable states averaging over a subset of patterns, and +3. fixed points which store a single pattern. + +Transformers learn an attention mechanism by constructing an embedding of patterns and queries into an associative +space.
Transformer and BERT models operate in their first layers preferably in the global averaging regime, while they +operate in higher layers in metastable states. The gradient in transformers is maximal in the regime of metastable +states, is uniformly distributed when averaging globally, and vanishes when a fixed point is near a stored pattern. +Based on the Hopfield network interpretation, we analyzed learning of transformer and BERT architectures. Learning starts +with attention heads that average and then most of them switch to metastable states. However, the majority of heads in +the first layers still averages and can be replaced by averaging operations like the Gaussian weighting that we propose. +In contrast, heads in the last layers steadily learn and seem to use metastable states to collect information created in +lower layers. These heads seem a promising target for improving transformers. Neural networks that integrate Hopfield +networks that are equivalent to attention heads outperform other methods on immune repertoire classification, where the +Hopfield net stores several hundreds of thousands of patterns. + +With _this_ repository, we provide a PyTorch implementation of a new layer called +“Hopfield”, which makes it possible to equip deep learning architectures with Hopfield networks as new memory concepts. + +The full paper is available at [https://arxiv.org/abs/2008.02217](https://arxiv.org/abs/2008.02217). + +## Requirements + +The software was developed and tested on the following 64-bit operating systems: + +- CentOS Linux release 8.1.1911 (Core) +- macOS 10.15.5 (Catalina) + +As the development environment, [Python](https://www.python.org) 3.8.3 in combination with [PyTorch](https://pytorch.org) 1.6.0 was used (a version of at least 1.5.0 should be sufficient). More details on how to install PyTorch are available on the [official project page](https://pytorch.org).
+ +## Usage + +To get up and running with Hopfield-based networks, only one argument needs to be set, the size (depth) of the input. + +```python +hopfield = Hopfield(input_size=...) +``` + +It is also possible to replace commonly used pooling functions with a Hopfield-based one. Internally, a state pattern is trained, which in turn is used to compute pooling weights with respect to the input. + +```python +hopfield_pooling = HopfieldPooling(input_size=...) +``` + +A second variant of our Hopfield-based modules is one which employs a trainable but fixed lookup mechanism. Internally, one or multiple stored patterns and pattern projections are trained (optionally in a non-shared manner), which in turn are used as a lookup mechanism independent of the input data. + +```python +hopfield_lookup = HopfieldLayer(input_size=...) +``` + +The usage is as simple as with the main module, but equally powerful. + +## Examples + +Generally, the Hopfield layer is designed to be used to implement or to substitute different layers like: + +- Pooling layers: We consider the Hopfield layer as a pooling layer if only one static state (query) pattern exists. Then, it is de facto a pooling over the sequence, which results from the softmax values applied on the stored patterns. Therefore, our Hopfield layer can act as a pooling layer. + +- Permutation equivariant layers: Our Hopfield layer can be used as a plug-in replacement for permutation equivariant layers. Since the Hopfield layer is an associative memory it assumes no dependency between the input patterns. + +- GRU & LSTM layers: Our Hopfield layer can be used as a plug-in replacement for GRU & LSTM layers. Optionally, for substituting GRU & LSTM layers, positional encoding might be considered. + +- Attention layers: Our Hopfield layer can act as an attention layer, where state (query) and stored (key) patterns are different, and need to be associated. 
+ +The folder [examples](examples/) contains multiple demonstrations on how to use the Hopfield, HopfieldPooling as well as the HopfieldLayer modules. To successfully run the contained [Jupyter notebooks](https://jupyter.org), additional third-party modules like [pandas](https://pandas.pydata.org) and [seaborn](https://seaborn.pydata.org) are required. + +- [Bit Pattern Set](examples/bit_pattern/bit_pattern_demo.ipynb): The dataset of this demonstration falls into the category of binary classification tasks in the domain of Multiple Instance Learning (MIL) problems. Each bag comprises a collection of bit pattern instances, whereas each instance is a sequence of 0s and 1s. The positive class has specific bit patterns injected, which are absent in the negative one. This demonstration shows that Hopfield, HopfieldPooling and HopfieldLayer are capable of learning and filtering each bag with respect to the class-defining bit patterns. + +- [Latch Sequence Set](examples/latch_sequence/latch_sequence_demo.ipynb): We study an easy example of learning long-term dependencies by using a simple latch task, see [Hochreiter and Mozer](https://link.springer.com/chapter/10.1007/3-540-44668-0_92). The essence of this task is that a sequence of inputs is presented, beginning with one of two symbols, A or B, and after a variable number of time steps, the model has to output a corresponding symbol. Thus, the task requires memorizing the original input over time. It has to be noted, that both class-defining symbols must only appear at the first position of a sequence. This task was specifically designed to demonstrate the capability of recurrent neural networks to capture long term dependencies. This demonstration shows that Hopfield, HopfieldPooling and HopfieldLayer adapt extremely fast to this specific task, concentrating only on the first entry of the sequence.
+ +- [Attention-based Deep Multiple Instance Learning](examples/mnist_bags/mnist_bags_demo.ipynb): The dataset of this demonstration falls into the category of binary classification tasks in the domain of Multiple Instance Learning (MIL) problems, see [Ilse and Tomczak](https://arxiv.org/abs/1802.04712). Each bag comprises a collection of 28x28 grayscale images/instances, whereas each instance is a sequence of pixel values in the range of [0; 255]. The number of instances per bag is drawn from a Gaussian with specified mean and variance. The positive class is defined by the presence of the target number/digit, whereas the negative one by its absence. + +## Disclaimer + +Some implementations of this repository are based on existing ones of the official [PyTorch repository v1.6.0](https://github.com/pytorch/pytorch/tree/v1.6.0) and accordingly extended and modified. In the following, the involved parts are listed: + +- The implementation of [HopfieldCore](modules/activation.py#L11) is based on the implementation of [MultiheadAttention](https://github.com/pytorch/pytorch/blob/b31f58de6fa8bbda5353b3c77d9be4914399724d/torch/nn/modules/activation.py#L771). +- The implementation of [hopfield_core_forward](modules/functional.py#L8) is based on the implementation of [multi_head_attention_forward](https://github.com/pytorch/pytorch/blob/b31f58de6fa8bbda5353b3c77d9be4914399724d/torch/nn/functional.py#L3854). +- The implementation of [HopfieldEncoderLayer](modules/transformer.py#L12) is based on the implementation of [TransformerEncoderLayer](https://github.com/pytorch/pytorch/blob/b31f58de6fa8bbda5353b3c77d9be4914399724d/torch/nn/modules/transformer.py#L241). +- The implementation of [HopfieldDecoderLayer](modules/transformer.py#L88) is based on the implementation of [TransformerDecoderLayer](https://github.com/pytorch/pytorch/blob/b31f58de6fa8bbda5353b3c77d9be4914399724d/torch/nn/modules/transformer.py#L303).
import torch

from math import ceil
from torch.utils.data import Dataset
from typing import Dict, Optional, Sequence, Tuple, Union


class BitPatternSet(Dataset):
    """
    Binary multiple instance learning (MIL) data set comprising bit patterns as instances,
    with implanted bit patterns unique to one of the classes.

    The first "num_targets_high" bags form the positive class and carry implanted signal patterns;
    all remaining bags contain neither signal nor all-zero instances.
    """

    def __init__(self, num_bags: int, num_instances: int, num_signals: int, num_signals_per_bag: int = 1,
                 fraction_targets: float = 0.5, num_bits: int = 8, dtype: torch.dtype = torch.float32,
                 seed_signals: int = 43, seed_data: int = 44):
        """
        Create new binary bit pattern data set conforming to the specified properties.

        Note: generation re-seeds torch's global random number generator (first with "seed_signals",
        then with "seed_data").

        :param num_bags: amount of bags
        :param num_instances: amount of instances per bag
        :param num_signals: amount of unique signals used to distinguish both classes
        :param num_signals_per_bag: amount of unique signals to be implanted per bag
        :param fraction_targets: fraction of targets
        :param num_bits: amount of bits per instance
        :param dtype: data type of instances
        :param seed_signals: random seed used to generate the signals of the data set (excl. samples)
        :param seed_data: random seed used to generate the samples of the data set (excl. signals)
        """
        super(BitPatternSet, self).__init__()
        assert (type(num_bags) == int) and (num_bags > 0), r'"num_bags" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        assert (type(num_signals) == int) and (num_signals > 0), r'"num_signals" must be a positive integer!'
        assert (type(num_signals_per_bag) == int) and (num_signals_per_bag >= 0) and (
                num_signals_per_bag <= num_instances), r'"num_signals_per_bag" must be a non-negative integer!'
        assert (type(fraction_targets) == float) and (fraction_targets > 0) and (
                fraction_targets < 1), r'"fraction_targets" must be in interval (0; 1)!'
        assert (type(num_bits) == int) and (num_bits > 0), r'"num_bits" must be a positive integer!'
        assert ((2 ** num_bits) - 1) > num_signals, r'"num_signals" must be smaller than "2 ** num_bits - 1"!'
        assert type(seed_signals) == int, r'"seed_signals" must be an integer!'
        assert type(seed_data) == int, r'"seed_data" must be an integer!'

        self.__num_bags = num_bags
        self.__num_instances = num_instances
        self.__num_signals = num_signals
        self.__num_signals_per_bag = num_signals_per_bag
        self.__fraction_targets = fraction_targets
        # At least one and at most all bags belong to the positive class.
        self.__num_targets = min(self.__num_bags, max(1, ceil(self.__num_bags * self.__fraction_targets)))
        self.__num_bits = num_bits
        self.__dtype = dtype
        self.__seed_signals = seed_signals
        self.__seed_data = seed_data
        self.__data, self.__targets, self.__signals = self._generate_bit_pattern_set()

    def __len__(self) -> int:
        """
        Fetch amount of bags.

        :return: amount of bags
        """
        return self.__num_bags

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific bag.

        :param item_index: specific bag to fetch
        :return: specific bag as dictionary of tensors
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_bit_pattern_set(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate underlying bit pattern data set.

        :return: tuple containing generated bags, targets and signals
        """
        torch.random.manual_seed(seed=self.__seed_signals)

        # Generate signal patterns (unique and not all-zero), topping up until enough are available.
        generated_signals = torch.randint(low=0, high=2, size=(self.__num_signals, self.__num_bits))
        check_instances = True
        while check_instances:
            generated_signals = torch.unique(input=generated_signals, dim=0)
            generated_signals = generated_signals[generated_signals.sum(axis=1) != 0]
            missing_signals = self.__num_signals - generated_signals.shape[0]
            if missing_signals > 0:
                generated_signals = torch.cat(tensors=(
                    generated_signals, torch.randint(low=0, high=2, size=(missing_signals, self.__num_bits))), dim=0)
            else:
                check_instances = False

        # Generate data and target tensors.
        torch.random.manual_seed(seed=self.__seed_data)
        generated_data = torch.randint(low=0, high=2, size=(self.__num_bags, self.__num_instances, self.__num_bits))
        generated_targets = torch.zeros(size=(self.__num_bags,), dtype=generated_data.dtype)
        generated_targets[:self.__num_targets] = 1

        # Check invalid (all-zero and signal) instances and re-sample them.
        check_instances = True
        while check_instances:
            invalid_instances = (generated_data.sum(axis=2) == 0).logical_or(
                torch.sum(torch.stack([(generated_data == signal).all(axis=2) for signal in generated_signals]),
                          axis=0))
            # Plain int: "size" is documented to take integers, not 0-dim tensors.
            num_invalid = int(invalid_instances.sum())
            if num_invalid > 0:
                generated_data[invalid_instances] = torch.randint(
                    low=0, high=2, size=(num_invalid, self.__num_bits))
            else:
                check_instances = False

        # Re-implant signal into respective bags.
        for data_index in range(self.__num_targets):
            implantation_indices = []
            for _ in range(self.__num_signals_per_bag):
                # Draw positions until an unused one is found, so no implanted signal gets overwritten.
                while True:
                    current_implantation_index = torch.randint(low=0, high=generated_data.shape[1], size=(1,))
                    if current_implantation_index not in implantation_indices:
                        implantation_indices.append(current_implantation_index)
                        break
                current_signal_index = torch.randint(low=0, high=generated_signals.shape[0], size=(1,))
                generated_data[data_index, current_implantation_index] = generated_signals[current_signal_index]

        return generated_data, generated_targets, generated_signals

    @property
    def num_bags(self) -> int:
        return self.__num_bags

    @property
    def num_instances(self) -> int:
        return self.__num_instances

    @property
    def num_bits(self) -> int:
        return self.__num_bits

    @property
    def num_targets_high(self) -> int:
        return self.__num_targets

    @property
    def num_targets_low(self) -> int:
        return self.__num_bags - self.__num_targets

    @property
    def num_signals(self) -> int:
        return self.__num_signals

    @property
    def num_signals_per_bag(self) -> int:
        return self.__num_signals_per_bag

    @property
    def seed_signals(self) -> int:
        """Seed used to generate the signals."""
        return self.__seed_signals

    @property
    def seed_data(self) -> int:
        """Seed used to generate the samples."""
        return self.__seed_data

    @property
    def initial_seed(self) -> int:
        """
        Seed the generation starts from, i.e. the signal seed.

        This class has no single "seed" attribute; see "seed_signals" and "seed_data" for both seeds.
        """
        return self.__seed_signals

    @property
    def bags(self) -> torch.Tensor:
        return self.__data.clone()

    @property
    def targets(self) -> torch.Tensor:
        return self.__targets.clone()

    @property
    def signals(self) -> torch.Tensor:
        return self.__signals.clone()


class LatchSequenceSet(Dataset):
    """
    Latch data set comprising patterns as one-hot encoded instances.

    The first instance of every sample is one of the two class-defining characters (0 or 1) and is
    also the target; all other instances are drawn from the remaining characters (2 and above).
    """

    def __init__(self, num_samples: int, num_instances: int = 20, num_characters: int = 6,
                 dtype: torch.dtype = torch.float32, seed: int = 43):
        """
        Create new latch sequence data set conforming to the specified properties.

        Note: generation re-seeds torch's global random number generator with "seed".

        :param num_samples: amount of samples
        :param num_instances: amount of instances per sample
        :param num_characters: amount of different characters (two of them are reserved as class-defining
                               characters, hence at least three are required)
        :param dtype: data type of samples
        :param seed: random seed used to generate the samples of the data set
        """
        super(LatchSequenceSet, self).__init__()
        assert (type(num_samples) == int) and (num_samples > 0), r'"num_samples" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        # Characters 0 and 1 are the signal; the filler is sampled from [2, num_characters), which must not be empty.
        assert (type(num_characters) == int) and (
                num_characters > 2), r'"num_characters" must be an integer greater than 2!'
        assert type(seed) == int, r'"seed" must be an integer!'

        self.__num_samples = num_samples
        self.__num_instances = num_instances
        self.__num_characters = num_characters
        self.__dtype = dtype
        self.__seed = seed
        self.__data, self.__targets = self._generate_latch_sequences()

    def __len__(self) -> int:
        """
        Fetch amount of samples.

        :return: amount of samples
        """
        return self.__num_samples

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific sample.

        :param item_index: specific sample to fetch
        :return: specific sample as dictionary of tensors
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_latch_sequences(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate underlying latch sequence data set.

        :return: tuple containing generated data and targets
        """
        torch.random.manual_seed(seed=self.__seed)

        # Generate data and target tensors.
        generated_data = torch.randint(
            low=2, high=self.__num_characters, size=(self.__num_samples, self.__num_instances))
        generated_signal = torch.randint(low=0, high=2, size=(self.__num_samples,))
        # The class-defining character only ever appears at the first position.
        generated_data[:, 0] = generated_signal
        generated_data = torch.nn.functional.one_hot(input=generated_data, num_classes=self.__num_characters)

        return generated_data, generated_signal

    @property
    def num_samples(self) -> int:
        return self.__num_samples

    @property
    def num_instances(self) -> int:
        return self.__num_instances

    @property
    def num_characters(self) -> int:
        return self.__num_characters

    @property
    def initial_seed(self) -> int:
        return self.__seed

    @property
    def targets(self) -> torch.Tensor:
        return self.__targets.clone()
class BitPatternSet(Dataset):
    """
    Binary multiple instance learning (MIL) data set comprising bit patterns as instances,
    with implanted bit patterns unique to one of the classes.

    The bags of the positive class occupy the first ``num_targets_high`` indices; each of them
    receives ``num_signals_per_bag`` implanted signal patterns. All remaining instances are
    guaranteed to be neither all-zero nor equal to any signal pattern.
    """

    def __init__(self, num_bags: int, num_instances: int, num_signals: int, num_signals_per_bag: int = 1,
                 fraction_targets: float = 0.5, num_bits: int = 8, dtype: torch.dtype = torch.float32,
                 seed_signals: int = 43, seed_data: int = 44):
        """
        Create new binary bit pattern data set conforming to the specified properties.

        :param num_bags: amount of bags
        :param num_instances: amount of instances per bag
        :param num_signals: amount of unique signals used to distinguish both classes
        :param num_signals_per_bag: amount of unique signals to be implanted per bag
        :param fraction_targets: fraction of targets
        :param num_bits: amount of bits per instance
        :param dtype: data type of instances
        :param seed_signals: random seed used to generate the signals of the data set (excl. samples)
        :param seed_data: random seed used to generate the samples of the data set (excl. signals)
        :raises AssertionError: if any argument violates the constraints checked below
        """
        super(BitPatternSet, self).__init__()
        # NOTE: validation relies on "assert" and is therefore skipped when running with "python -O".
        assert (type(num_bags) == int) and (num_bags > 0), r'"num_bags" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        assert (type(num_signals) == int) and (num_signals > 0), r'"num_signals" must be a positive integer!'
        assert (type(num_signals_per_bag) == int) and (
                num_signals_per_bag >= 0), r'"num_signals_per_bag" must be a non-negative integer!'
        # Implantation positions within one bag are distinct, hence there cannot be more signals than instances.
        assert num_signals_per_bag <= num_instances, r'"num_signals_per_bag" must not exceed "num_instances"!'
        assert (type(fraction_targets) == float) and (fraction_targets > 0) and (
                fraction_targets < 1), r'"fraction_targets" must be in interval (0; 1)!'
        assert (type(num_bits) == int) and (num_bits > 0), r'"num_bits" must be a positive integer!'
        # At least one non-zero, non-signal pattern has to remain available for ordinary instances.
        assert ((2 ** num_bits) - 1) > num_signals, r'"num_signals" must be smaller than "2 ** num_bits - 1"!'
        assert type(seed_signals) == int, r'"seed_signals" must be an integer!'
        assert type(seed_data) == int, r'"seed_data" must be an integer!'

        self.__num_bags = num_bags
        self.__num_instances = num_instances
        self.__num_signals = num_signals
        self.__num_signals_per_bag = num_signals_per_bag
        self.__fraction_targets = fraction_targets
        # At least one and at most "num_bags" bags belong to the positive class.
        self.__num_targets = min(self.__num_bags, max(1, ceil(self.__num_bags * self.__fraction_targets)))
        self.__num_bits = num_bits
        self.__dtype = dtype
        self.__seed_signals = seed_signals
        self.__seed_data = seed_data
        self.__data, self.__targets, self.__signals = self._generate_bit_pattern_set()

    def __len__(self) -> int:
        """
        Fetch amount of bags.

        :return: amount of bags
        """
        return self.__num_bags

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific bag.

        :param item_index: specific bag to fetch
        :return: specific bag as dictionary of tensors (keys "data" and "target"), cast to the configured data type
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_bit_pattern_set(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate underlying bit pattern data set.

        The global PyTorch random number generator is re-seeded twice (signals first, data afterwards).

        :return: tuple containing generated bags, targets and signals
        """
        torch.random.manual_seed(seed=self.__seed_signals)

        # Generate signal patterns: re-draw until enough unique, non-zero patterns are available.
        generated_signals = torch.randint(low=0, high=2, size=(self.__num_signals, self.__num_bits))
        while True:
            generated_signals = torch.unique(input=generated_signals, dim=0)
            generated_signals = generated_signals[generated_signals.sum(dim=1) != 0]
            missing_signals = self.__num_signals - generated_signals.shape[0]
            if missing_signals <= 0:
                break
            generated_signals = torch.cat(tensors=(
                generated_signals, torch.randint(low=0, high=2, size=(missing_signals, self.__num_bits))), dim=0)

        # Generate data and target tensors (positive bags come first).
        torch.random.manual_seed(seed=self.__seed_data)
        generated_data = torch.randint(low=0, high=2, size=(self.__num_bags, self.__num_instances, self.__num_bits))
        generated_targets = torch.zeros(size=(self.__num_bags,), dtype=generated_data.dtype)
        generated_targets[:self.__num_targets] = 1

        # Check invalid (all-zero and signal) instances and re-sample them until none is left.
        while True:
            is_zero = generated_data.sum(dim=2) == 0
            is_signal = torch.stack([(generated_data == signal).all(dim=2) for signal in generated_signals]).any(dim=0)
            invalid_instances = is_zero | is_signal
            num_invalid = int(invalid_instances.sum())
            if num_invalid == 0:
                break
            generated_data[invalid_instances] = torch.randint(low=0, high=2, size=(num_invalid, self.__num_bits))

        # Re-implant signal into respective bags, using distinct instance positions within each bag.
        for data_index in range(self.__num_targets):
            implantation_indices = set()
            for _ in range(self.__num_signals_per_bag):
                while True:
                    current_implantation_index = int(torch.randint(low=0, high=generated_data.shape[1], size=(1,)))
                    if current_implantation_index not in implantation_indices:
                        implantation_indices.add(current_implantation_index)
                        break
                current_signal_index = int(torch.randint(low=0, high=generated_signals.shape[0], size=(1,)))
                generated_data[data_index, current_implantation_index] = generated_signals[current_signal_index]

        return generated_data, generated_targets, generated_signals

    @property
    def num_bags(self) -> int:
        return self.__num_bags

    @property
    def num_instances(self) -> int:
        return self.__num_instances

    @property
    def num_bits(self) -> int:
        return self.__num_bits

    @property
    def num_targets_high(self) -> int:
        return self.__num_targets

    @property
    def num_targets_low(self) -> int:
        return self.__num_bags - self.__num_targets

    @property
    def num_signals(self) -> int:
        return self.__num_signals

    @property
    def num_signals_per_bag(self) -> int:
        return self.__num_signals_per_bag

    @property
    def seed_signals(self) -> int:
        return self.__seed_signals

    @property
    def seed_data(self) -> int:
        return self.__seed_data

    @property
    def initial_seed(self) -> int:
        # This class stores two seeds and no single "seed" attribute; the seed used for
        # generating the samples is reported here (see "seed_signals" for the other one).
        return self.__seed_data

    @property
    def bags(self) -> torch.Tensor:
        return self.__data.clone()

    @property
    def targets(self) -> torch.Tensor:
        return self.__targets.clone()

    @property
    def signals(self) -> torch.Tensor:
        return self.__signals.clone()


class LatchSequenceSet(Dataset):
    """
    Latch data set comprising patterns as one-hot encoded instances.

    The first instance of every sample encodes the target (character 0 or 1); all other
    instances are drawn from the remaining characters (2 up to "num_characters" - 1).
    """

    def __init__(self, num_samples: int, num_instances: int = 20, num_characters: int = 6,
                 dtype: torch.dtype = torch.float32, seed: int = 43):
        """
        Create new latch sequence data set conforming to the specified properties.

        :param num_samples: amount of samples
        :param num_instances: amount of instances per sample
        :param num_characters: amount of different characters (must be greater than 2)
        :param dtype: data type of samples
        :param seed: random seed used to generate the samples of the data set
        :raises AssertionError: if any argument violates the constraints checked below
        """
        super(LatchSequenceSet, self).__init__()
        assert (type(num_samples) == int) and (num_samples > 0), r'"num_samples" must be a positive integer!'
        assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        assert (type(num_characters) == int) and (num_characters > 0), r'"num_characters" must be a positive integer!'
        # Characters 0 and 1 are reserved for the latch signal, so at least one further character is
        # required; otherwise drawing from the interval [2, num_characters) fails.
        assert num_characters > 2, r'"num_characters" must be greater than 2!'
        assert type(seed) == int, r'"seed" must be an integer!'

        self.__num_samples = num_samples
        self.__num_instances = num_instances
        self.__num_characters = num_characters
        self.__dtype = dtype
        self.__seed = seed
        self.__data, self.__targets = self._generate_latch_sequences()

    def __len__(self) -> int:
        """
        Fetch amount of samples.

        :return: amount of samples
        """
        return self.__num_samples

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific sample.

        :param item_index: specific sample to fetch
        :return: specific sample as dictionary of tensors (keys "data" and "target"), cast to the configured data type
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_latch_sequences(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate underlying latch sequence data set.

        The global PyTorch random number generator is re-seeded with the configured seed.

        :return: tuple containing generated data and targets
        """
        torch.random.manual_seed(seed=self.__seed)

        # Generate data and target tensors: filler characters first, then the signal at position 0.
        generated_data = torch.randint(
            low=2, high=self.__num_characters, size=(self.__num_samples, self.__num_instances))
        generated_signal = torch.randint(low=0, high=2, size=(self.__num_samples,))
        generated_data[:, 0] = generated_signal
        generated_data = torch.nn.functional.one_hot(input=generated_data, num_classes=self.__num_characters)

        return generated_data, generated_signal

    @property
    def num_samples(self) -> int:
        return self.__num_samples

    @property
    def num_instances(self) -> int:
        return self.__num_instances

    @property
    def num_characters(self) -> int:
        return self.__num_characters

    @property
    def initial_seed(self) -> int:
        return self.__seed

    @property
    def targets(self) -> torch.Tensor:
        return self.__targets.clone()

Example: Bit Pattern Set

\n", + "\n", + "The dataset of this demonstration falls into the category of binary classification tasks in the domain of Multiple Instance Learning (MIL) problems. Each bag comprises a collection of bit pattern instances, where each instance is a sequence of 0s and 1s. The positive class has specific bit patterns injected, which are absent in the negative one. This demonstration shows that Hopfield, HopfieldPooling and HopfieldLayer are capable of learning and filtering each bag with respect to the class-defining bit patterns.\n", + "\n", + "This demonstration shows how to apply Hopfield, HopfieldPooling and HopfieldLayer to an exemplary Multiple Instance Learning problem.\n", + "\n", + "

In the chapters Adapt Hopfield-based Network, Adapt Hopfield-based Pooling and Adapt Hopfield-based Lookup you can explore and try the new functionalities of our new Hopfield layer.

\n", + "\n", + "In order to run this notebook, a few modules need to be imported. The installation of third-party modules is not covered here." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Import general modules used e.g. for plotting.\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import sys\n", + "import torch\n", + "\n", + "# Importing Hopfield-specific modules.\n", + "from auxiliary.data import BitPatternSet\n", + "from modules import Hopfield, HopfieldPooling, HopfieldLayer\n", + "\n", + "# Import auxiliary modules.\n", + "from distutils.version import LooseVersion\n", + "from typing import List, Tuple\n", + "\n", + "# Importing PyTorch specific modules.\n", + "from torch import Tensor\n", + "from torch.nn import Flatten, Linear, Module, Sequential\n", + "from torch.nn.functional import binary_cross_entropy_with_logits\n", + "from torch.nn.utils import clip_grad_norm_\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.data.sampler import SubsetRandomSampler\n", + "\n", + "# Set plotting style.\n", + "sns.set()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specific minimum versions of Python itself as well as of some of the modules used are required." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Installed Python version: 3.8.8 (✓)\n", + "Installed PyTorch version: 1.7.0 (✓)\n" + ] + } + ], + "source": [ + "python_check = '(\\u2713)' if sys.version_info >= (3, 8) else '(\\u2717)'\n", + "pytorch_check = '(\\u2713)' if torch.__version__ >= LooseVersion(r'1.5') else '(\\u2717)'\n", + "\n", + "print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} {python_check}')\n", + "print(f'Installed PyTorch version: {torch.__version__} {pytorch_check}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Create Dataset

\n", + "\n", + "The dataset itself falls into the category of binary classification tasks in the domain of Multiple Instance Learning (MIL) problems. Defining arguments are:\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Argument | Value (used in this demo) | Description
num_bags | 2048 | Amount of bags (samples) of the full dataset.
num_instances | 16 | Amount of instances per bag (sample).
num_signals | 8 | Amount of unique instances indicative of the positive class.
num_signals_per_bag | 1 | Amount of signals implanted into one bag of the positive class.
num_bits | 8 | Amount of \"bits\" (feature size) per instance.
... | default | The remaining arguments are not explicitly used in this demo.
\n", + "\n", + "Let's define the dataset using previously mentioned properties as well as a logging directory for storing all auxiliary outputs like performance plots." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "bit_pattern_set = BitPatternSet(\n", + " num_bags=2048,\n", + " num_instances=16,\n", + " num_signals=8,\n", + " num_signals_per_bag=1,\n", + " num_bits=8)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "log_dir = f'resources/'\n", + "os.makedirs(log_dir, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Create Auxiliaries

\n", + "\n", + "Before digging into Hopfield-based networks, a few auxiliary variables and functions need to be defined. This is nothing special with respect to Hopfield-based networks, but rather common preparation work of (almost) every machine learning setting (e.g. definition of a data loader as well as a training loop). We will see, that this comprises the most work of this whole demo." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(r'cuda:0' if torch.cuda.is_available() else r'cpu')\n", + "\n", + "# Create data loader of training set.\n", + "sampler_train = SubsetRandomSampler(list(range(256, 2048 - 256)))\n", + "data_loader_train = DataLoader(dataset=bit_pattern_set, batch_size=32, sampler=sampler_train)\n", + "\n", + "# Create data loader of validation set.\n", + "sampler_eval = SubsetRandomSampler(list(range(256)) + list(range(2048 - 256, 2048)))\n", + "data_loader_eval = DataLoader(dataset=bit_pattern_set, batch_size=32, sampler=sampler_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def train_epoch(network: Module,\n", + " optimiser: AdamW,\n", + " data_loader: DataLoader\n", + " ) -> Tuple[float, float]:\n", + " \"\"\"\n", + " Execute one training epoch.\n", + " \n", + " :param network: network instance to train\n", + " :param optimiser: optimiser instance responsible for updating network parameters\n", + " :param data_loader: data loader instance providing training data\n", + " :return: tuple comprising training loss as well as accuracy\n", + " \"\"\"\n", + " network.train()\n", + " losses, accuracies = [], []\n", + " for sample_data in data_loader:\n", + " data, target = sample_data[r'data'], sample_data[r'target']\n", + " data, target = data.to(device=device), target.to(device=device)\n", + "\n", + " # Process data by Hopfield-based network.\n", + " model_output = network.forward(input=data)\n", 
+ "\n", + " # Update network parameters.\n", + " optimiser.zero_grad()\n", + " loss = binary_cross_entropy_with_logits(input=model_output, target=target, reduction=r'mean')\n", + " loss.backward()\n", + " clip_grad_norm_(parameters=network.parameters(), max_norm=1.0, norm_type=2)\n", + " optimiser.step()\n", + "\n", + " # Compute performance measures of current model.\n", + " accuracy = (model_output.sigmoid().round() == target).to(dtype=torch.float32).mean()\n", + " accuracies.append(accuracy.detach().item())\n", + " losses.append(loss.detach().item())\n", + " \n", + " # Report progress of training procedure.\n", + " return (sum(losses) / len(losses), sum(accuracies) / len(accuracies))\n", + "\n", + "\n", + "def eval_iter(network: Module,\n", + " data_loader: DataLoader\n", + " ) -> Tuple[float, float]:\n", + " \"\"\"\n", + " Evaluate the current model.\n", + " \n", + " :param network: network instance to evaluate\n", + " :param data_loader: data loader instance providing validation data\n", + " :return: tuple comprising validation loss as well as accuracy\n", + " \"\"\"\n", + " network.eval()\n", + " with torch.no_grad():\n", + " losses, accuracies = [], []\n", + " for sample_data in data_loader:\n", + " data, target = sample_data[r'data'], sample_data[r'target']\n", + " data, target = data.to(device=device), target.to(device=device)\n", + "\n", + " # Process data by Hopfield-based network.\n", + " model_output = network.forward(input=data)\n", + " loss = binary_cross_entropy_with_logits(input=model_output, target=target, reduction=r'mean')\n", + "\n", + " # Compute performance measures of current model.\n", + " accuracy = (model_output.sigmoid().round() == target).to(dtype=torch.float32).mean()\n", + " accuracies.append(accuracy.detach().item())\n", + " losses.append(loss.detach().item())\n", + "\n", + " # Report progress of validation procedure.\n", + " return (sum(losses) / len(losses), sum(accuracies) / len(accuracies))\n", + "\n", + "\n", + "def 
operate(network: Module,\n", + " optimiser: AdamW,\n", + " data_loader_train: DataLoader,\n", + " data_loader_eval: DataLoader,\n", + " num_epochs: int = 1\n", + " ) -> Tuple[pd.DataFrame, pd.DataFrame]:\n", + " \"\"\"\n", + " Train the specified network by gradient descent using backpropagation.\n", + " \n", + " :param network: network instance to train\n", + " :param optimiser: optimiser instance responsible for updating network parameters\n", + " :param data_loader_train: data loader instance providing training data\n", + " :param data_loader_eval: data loader instance providing validation data\n", + " :param num_epochs: amount of epochs to train\n", + " :return: data frame comprising training as well as evaluation performance\n", + " \"\"\"\n", + " losses, accuracies = {r'train': [], r'eval': []}, {r'train': [], r'eval': []}\n", + " for epoch in range(num_epochs):\n", + " \n", + " # Train network.\n", + " performance = train_epoch(network, optimiser, data_loader_train)\n", + " losses[r'train'].append(performance[0])\n", + " accuracies[r'train'].append(performance[1])\n", + " \n", + " # Evaluate current model.\n", + " performance = eval_iter(network, data_loader_eval)\n", + " losses[r'eval'].append(performance[0])\n", + " accuracies[r'eval'].append(performance[1])\n", + " \n", + " # Report progress of training and validation procedures.\n", + " return pd.DataFrame(losses), pd.DataFrame(accuracies)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def set_seed(seed: int = 42) -> None:\n", + " \"\"\"\n", + " Set seed for all underlying (pseudo) random number sources.\n", + " \n", + " :param seed: seed to be used\n", + " :return: None\n", + " \"\"\"\n", + " torch.manual_seed(42)\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + "\n", + "def plot_performance(loss: pd.DataFrame,\n", + " accuracy: pd.DataFrame,\n", + " log_file: str\n", + " ) -> 
None:\n", + " \"\"\"\n", + " Plot and save loss and accuracy.\n", + " \n", + " :param loss: loss to be plotted\n", + " :param accuracy: accuracy to be plotted\n", + " :param log_file: target file for storing the resulting plot\n", + " :return: None\n", + " \"\"\"\n", + " fig, ax = plt.subplots(1, 2, figsize=(20, 7))\n", + " \n", + " loss_plot = sns.lineplot(data=loss, ax=ax[0])\n", + " loss_plot.set(xlabel=r'Epoch', ylabel=r'Cross-entropy Loss')\n", + " \n", + " accuracy_plot = sns.lineplot(data=accuracy, ax=ax[1])\n", + " accuracy_plot.set(xlabel=r'Epoch', ylabel=r'Accuracy')\n", + " \n", + " ax[1].yaxis.set_label_position(r'right')\n", + " fig.tight_layout()\n", + " fig.savefig(log_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Hopfield-based Network

\n", + "\n", + "The instantiation of the heart of a Hopfield-based network, the module Hopfield, is rather straightforward. Only one argument, the size of the input, needs to be set.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Argument | Value (used in this demo) | Description
input_size | num_bits (8) | Size (depth) of the input (state pattern).
... | default | The remaining arguments are not explicitly used in this example.
\n", + "\n", + "An additional output projection is defined, to downproject the result of Hopfield to the correct output size. Afterwards, everything is wrapped into a container of type torch.nn.Sequential and a corresponding optimiser is defined. Now the Hopfield-based network and all auxiliaries are set up and ready to associate!" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield = Hopfield(\n", + " input_size=bit_pattern_set.num_bits)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield.output_size * bit_pattern_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield, Flatten(), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate Hopfield-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=250)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_base.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Adapt Hopfield-based Network

\n", + "

We can now explore the functionality of our Hopfield layer Hopfield.

\n", + "\n", + "As described in the paper the Hopfield layer allows:\n", + "- association of two sets\n", + "- multiple updates\n", + "- variable beta\n", + "- changing the dimension of the associative space\n", + "- pattern normalization\n", + "- static patterns for fixed pattern search\n", + "\n", + "This time, additional arguments are set to increase the training as well as the validation performance of the Hopfield-based network.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Argument | Value (used in this demo) | Description
input_size | num_bits (8) | Size (depth) of the input (state pattern).
hidden_size | 8 | Size (depth) of the association space.
num_heads | 8 | Amount of parallel association heads.
update_steps_max | 3 | Number of updates in one Hopfield head.
scaling | 0.25 | Beta parameter that determines the kind of fixed point.
dropout | 0.5 | Dropout probability applied on the association matrix.
... | default | The remaining arguments are not explicitly used in this example.
" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield = Hopfield(\n", + " input_size=bit_pattern_set.num_bits,\n", + " hidden_size=8,\n", + " num_heads=8,\n", + " update_steps_max=3,\n", + " scaling=0.25,\n", + " dropout=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield.output_size * bit_pattern_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield, Flatten(), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=250)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_adapted.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Hopfield-based Pooling

\n", + "\n", + "The previous examples manually downprojected the result of Hopfield by applying a linear layer afterwards. It would've also been possible to apply some kind of pooling. Exactly for such use cases, the module HopfieldPooling might be the right choice. Internally, a state pattern is trained, which in turn is used to compute pooling weights with respect to the input." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_pooling = HopfieldPooling(\n", + " input_size=bit_pattern_set.num_bits)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_pooling.output_size, out_features=1)\n", + "network = Sequential(hopfield_pooling, output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate Hopfield-based Pooling

" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=250)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_pooling.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Adapt Hopfield-based Pooling

\n", + "

We can now again explore the functionality of our Hopfield-based pooling layer HopfieldPooling.

\n", + "\n", + "Again, additional arguments are set to increase the training as well as the validation performance of the Hopfield-based pooling.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Argument | Value (used in this demo) | Description
input_size | num_bits (8) | Size (depth) of the input (state pattern).
hidden_size | 8 | Size (depth) of the association space.
num_heads | 8 | Amount of parallel association heads.
update_steps_max | 3 | Number of updates in one Hopfield head.
scaling | 0.25 | Beta parameter that determines the kind of fixed point.
dropout | 0.5 | Dropout probability applied on the association matrix.
... | default | The remaining arguments are not explicitly used in this example.
" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_pooling = HopfieldPooling(\n", + " input_size=bit_pattern_set.num_bits,\n", + " hidden_size=8,\n", + " num_heads=8,\n", + " update_steps_max=3,\n", + " scaling=0.25,\n", + " dropout=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_pooling.output_size, out_features=1)\n", + "network = Sequential(hopfield_pooling, output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=250)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_pooling_adapted.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Hopfield-based Lookup

\n", + "\n", + "In contrast to the first Hopfield setting, in which the state patterns as well as the stored patterns are directly dependent on the input, HopfieldLayer employs a trainable but fixed stored pattern matrix, which in turn acts as a learnable lookup table." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "bit_samples_unique = [_[r'data'] for _ in data_loader_train]\n", + "bit_samples_unique = torch.cat(bit_samples_unique).view(-1, bit_samples_unique[0].shape[2]).unique(dim=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_lookup = HopfieldLayer(\n", + " input_size=bit_pattern_set.num_bits,\n", + " quantity=len(bit_samples_unique))" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_lookup.output_size * bit_pattern_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield_lookup, Flatten(start_dim=1), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate Hopfield-based Lookup

" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=250)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_lookup.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Adapt Hopfield-based Lookup

\n", + "

We can now again explore the functionality of our Hopfield-based lookup layer HopfieldLayer.

\n", + "\n", + "This lookup setting is especially pronounced if the stored patterns are initialized with a subset of the training set (and, optionally, the corresponding training targets are provided as pattern projection inputs).\n", + "\n", + "Again, additional arguments are set to increase the training as well as the validation performance of the Hopfield-based lookup.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Argument | Value (used in this demo) | Description
lookup_weights_as_separated | True | Separate lookup weights from lookup target weights (e.g. to set lookup target weights separately).
lookup_targets_as_trainable | False | Employ trainable lookup target weights (used as pattern projection input).
" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_lookup = HopfieldLayer(\n", + " input_size=bit_pattern_set.num_bits,\n", + " quantity=len(bit_samples_unique),\n", + " lookup_weights_as_separated=True,\n", + " lookup_targets_as_trainable=False,\n", + " normalize_stored_pattern_affine=True,\n", + " normalize_pattern_projection_affine=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialize the trainable but fixed stored patterns with all unique samples from the training set. In this way, the Hopfield-based lookup already starts with meaningful stored patterns (instead of random noise). This may enhance the performance of the network, especially at the beginning of the training." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " hopfield_lookup.lookup_weights[:] = bit_samples_unique.unsqueeze(dim=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_lookup.output_size * bit_pattern_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield_lookup, Flatten(start_dim=1), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=250)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
class Hopfield(Module):
    """
    Module with underlying Hopfield association.

    Wraps a "HopfieldCore" instance and optionally applies layer normalization to the stored pattern,
    the state pattern and the pattern projection before they are handed to the core.
    """

    def __init__(self,
                 input_size: Optional[int] = None,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False
                 ):
        """
        Initialise new instance of a Hopfield module.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: additional activation to be applied on the result of the Hopfield association
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        """
        super(Hopfield, self).__init__()
        assert isinstance(batch_first, bool), f'"batch_first" needs to be a boolean, not {type(batch_first)}.'
        assert (association_activation is None) or isinstance(association_activation, str)

        # Initialise Hopfield association module.
        self.association_core = HopfieldCore(
            embed_dim=input_size, num_heads=num_heads, dropout=dropout, bias=input_bias,
            add_bias_kv=concat_bias_pattern, add_zero_attn=add_zero_association, kdim=stored_pattern_size,
            vdim=pattern_projection_size, head_dim=hidden_size, pattern_dim=pattern_size, out_dim=output_size,
            disable_out_projection=disable_out_projection, key_as_static=stored_pattern_as_static,
            query_as_static=state_pattern_as_static, value_as_static=pattern_projection_as_static,
            value_as_connected=pattern_projection_as_connected, normalize_pattern=normalize_hopfield_space,
            normalize_pattern_affine=normalize_hopfield_space_affine)

        # The activation is looked up by name in the "torch" namespace; an unknown name resolves to
        # <None>, which leaves the association output without an activation.
        self.association_activation = None
        if association_activation is not None:
            self.association_activation = getattr(torch, association_activation, None)

        # Initialise stored pattern normalization.
        self.norm_stored_pattern = None
        if normalize_stored_pattern_affine:
            assert normalize_stored_pattern, "affine normalization without normalization has no effect."
        if normalize_stored_pattern:
            normalized_shape = input_size if stored_pattern_size is None else stored_pattern_size
            assert normalized_shape is not None, "stored pattern size required for setting up normalisation"
            self.norm_stored_pattern = nn.LayerNorm(
                normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine)

        # Initialise state pattern normalization.
        self.norm_state_pattern = None
        if normalize_state_pattern_affine:
            assert normalize_state_pattern, "affine normalization without normalization has no effect."
        if normalize_state_pattern:
            assert input_size is not None, "input size required for setting up normalisation"
            self.norm_state_pattern = nn.LayerNorm(
                normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine)

        # Initialise pattern projection normalization.
        self.norm_pattern_projection = None
        if normalize_pattern_projection_affine:
            assert normalize_pattern_projection, "affine normalization without normalization has no effect."
        if normalize_pattern_projection:
            normalized_shape = input_size if pattern_projection_size is None else pattern_projection_size
            assert normalized_shape is not None, "pattern projection size required for setting up normalisation"
            self.norm_pattern_projection = nn.LayerNorm(
                normalized_shape=normalized_shape, elementwise_affine=normalize_pattern_projection_affine)

        # Initialise remaining auxiliary properties.
        if self.association_core.static_execution:
            self.__scaling = 1.0 if scaling is None else scaling
        else:
            assert self.association_core.head_dim > 0, 'invalid hidden dimension encountered.'
            # Default scaling is 1 / sqrt(head_dim) unless an explicit value is specified.
            self.__scaling = (1.0 / sqrt(self.association_core.head_dim)) if scaling is None else scaling
        self.__batch_first = batch_first
        self.__update_steps_max = update_steps_max
        self.__update_steps_eps = update_steps_eps

        # Keep the Hopfield space normalization flags exactly as forwarded to the association core, so
        # that the corresponding read-only properties can report them.
        self.__normalize_hopfield_space = normalize_hopfield_space
        self.__normalize_hopfield_space_affine = normalize_hopfield_space_affine
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset Hopfield association.

        :return: None
        """
        # Normalization modules may be <None> if disabled; those are skipped by the attribute check.
        for module in (self.association_core, self.norm_stored_pattern,
                       self.norm_state_pattern, self.norm_pattern_projection):
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    def _maybe_transpose(self, *args: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
        """
        Eventually transpose specified data.

        :param args: tensors to eventually transpose (dependent on the state of "batch_first")
        :return: eventually transposed tensors (a single tensor if exactly one was specified)
        """
        transposed_result = tuple(_.transpose(0, 1) for _ in args) if self.__batch_first else args
        return transposed_result[0] if len(transposed_result) == 1 else transposed_result

    def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                   return_raw_associations: bool = False, return_projected_patterns: bool = False,
                   stored_pattern_padding_mask: Optional[Tensor] = None,
                   association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]:
        """
        Apply Hopfield association module on specified data.

        :param data: data to be processed by Hopfield core module
        :param return_raw_associations: return raw association (softmax) values, unmodified
        :param return_projected_patterns: return pattern projection values, unmodified
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: Hopfield-processed input data
        """
        # "isinstance" (instead of an exact type comparison) also accepts tensor subclasses
        # such as "nn.Parameter".
        assert isinstance(data, Tensor) or (isinstance(data, tuple) and (len(data) == 3)), \
            r'either one tensor to be used as "stored pattern", "state pattern" and' \
            r' "pattern_projection" must be provided, or three separate ones.'
        if isinstance(data, Tensor):
            stored_pattern, state_pattern, pattern_projection = data, data, data
        else:
            stored_pattern, state_pattern, pattern_projection = data

        # Optionally transpose data.
        stored_pattern, state_pattern, pattern_projection = self._maybe_transpose(
            stored_pattern, state_pattern, pattern_projection)

        # Optionally apply stored pattern normalization (flattened to two dimensions, restored afterwards).
        if self.norm_stored_pattern is not None:
            stored_pattern = self.norm_stored_pattern(input=stored_pattern.reshape(
                shape=(-1, stored_pattern.shape[2]))).reshape(shape=stored_pattern.shape)

        # Optionally apply state pattern normalization.
        if self.norm_state_pattern is not None:
            state_pattern = self.norm_state_pattern(input=state_pattern.reshape(
                shape=(-1, state_pattern.shape[2]))).reshape(shape=state_pattern.shape)

        # Optionally apply pattern projection normalization.
        if self.norm_pattern_projection is not None:
            pattern_projection = self.norm_pattern_projection(input=pattern_projection.reshape(
                shape=(-1, pattern_projection.shape[2]))).reshape(shape=pattern_projection.shape)

        # Apply Hopfield association and optional activation function.
        return self.association_core(
            query=state_pattern, key=stored_pattern, value=pattern_projection,
            key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask,
            scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps,
            return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns)

    def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield association on specified data.

        :param input: data to be processed by Hopfield association module
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: Hopfield-processed input data
        """
        # The first entry of the association result is the output, transposed back if required.
        association_output = self._maybe_transpose(self._associate(
            data=input, return_raw_associations=False,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)[0])
        if self.association_activation is not None:
            association_output = self.association_activation(association_output)
        return association_output

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                               stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            return self._associate(
                data=input, return_raw_associations=True,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)[2]

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield projected pattern matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            return self._associate(
                data=input, return_projected_patterns=True,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)[3]

    @property
    def batch_first(self) -> bool:
        return self.__batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        # Tensors (including subclasses) are cloned to protect the internal state from modification.
        return self.__scaling.clone() if isinstance(self.__scaling, Tensor) else self.__scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        return self.association_core.kdim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        return self.association_core.embed_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        return self.association_core.vdim

    @property
    def input_size(self) -> Optional[int]:
        return self.state_pattern_dim

    @property
    def hidden_size(self) -> Optional[int]:
        return self.association_core.head_dim

    @property
    def output_size(self) -> Optional[int]:
        return self.association_core.out_dim

    @property
    def pattern_size(self) -> Optional[int]:
        return self.association_core.pattern_dim

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        return self.__update_steps_max.clone() if isinstance(self.__update_steps_max, Tensor) \
            else self.__update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        return self.__update_steps_eps.clone() if isinstance(self.__update_steps_eps, Tensor) \
            else self.__update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        return self.association_core.key_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        return self.association_core.query_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        return self.association_core.value_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        return self.norm_stored_pattern is not None

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        return self.normalize_stored_pattern and self.norm_stored_pattern.elementwise_affine

    @property
    def normalize_state_pattern(self) -> bool:
        return self.norm_state_pattern is not None

    @property
    def normalize_state_pattern_affine(self) -> bool:
        return self.normalize_state_pattern and self.norm_state_pattern.elementwise_affine

    @property
    def normalize_pattern_projection(self) -> bool:
        return self.norm_pattern_projection is not None

    @property
    def normalize_pattern_projection_affine(self) -> bool:
        return self.normalize_pattern_projection and self.norm_pattern_projection.elementwise_affine

    @property
    def normalize_hopfield_space(self) -> bool:
        # Reported from the flag stored at construction; this module has no "hopfield" attribute.
        return self.__normalize_hopfield_space

    @property
    def normalize_hopfield_space_affine(self) -> bool:
        return self.__normalize_hopfield_space_affine
+ """ + + def __init__(self, + input_size: int, + hidden_size: Optional[int] = None, + output_size: Optional[int] = None, + pattern_size: Optional[int] = None, + num_heads: int = 1, + scaling: Optional[Union[float, Tensor]] = None, + update_steps_max: Optional[Union[int, Tensor]] = 0, + update_steps_eps: Union[float, Tensor] = 1e-4, + + normalize_stored_pattern: bool = True, + normalize_stored_pattern_affine: bool = True, + normalize_state_pattern: bool = True, + normalize_state_pattern_affine: bool = True, + normalize_pattern_projection: bool = True, + normalize_pattern_projection_affine: bool = True, + normalize_hopfield_space: bool = False, + normalize_hopfield_space_affine: bool = False, + stored_pattern_as_static: bool = False, + state_pattern_as_static: bool = False, + pattern_projection_as_static: bool = False, + pattern_projection_as_connected: bool = False, + stored_pattern_size: Optional[int] = None, + pattern_projection_size: Optional[int] = None, + + batch_first: bool = True, + association_activation: Optional[str] = None, + dropout: float = 0.0, + input_bias: bool = True, + concat_bias_pattern: bool = False, + add_zero_association: bool = False, + disable_out_projection: bool = False, + quantity: int = 1, + trainable: bool = True + ): + """ + Initialise a new instance of a Hopfield-based pooling layer. 
+ + :param input_size: depth of the input (state pattern) + :param hidden_size: depth of the association space + :param output_size: depth of the output projection + :param pattern_size: depth of patterns to be selected + :param num_heads: amount of parallel association heads + :param scaling: scaling of association heads, often represented as beta (one entry per head) + :param update_steps_max: maximum count of association update steps (None equals to infinity) + :param update_steps_eps: minimum difference threshold between two consecutive association update steps + :param normalize_stored_pattern: apply normalization on stored patterns + :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns + :param normalize_state_pattern: apply normalization on state patterns + :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns + :param normalize_pattern_projection: apply normalization on the pattern projection + :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection + :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space + :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space + :param stored_pattern_as_static: interpret specified stored patterns as being static + :param state_pattern_as_static: interpret specified state patterns as being static + :param pattern_projection_as_static: interpret specified pattern projections as being static + :param pattern_projection_as_connected: connect pattern projection with stored pattern + :param stored_pattern_size: depth of input (stored pattern) + :param pattern_projection_size: depth of input (pattern projection) + :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size + :param association_activation: additional activation to be applied on the 
result of the Hopfield association + :param dropout: dropout probability applied on the association matrix + :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection) + :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection + :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection + :param disable_out_projection: disable output projection + :param quantity: amount of state patterns + :param trainable: state pattern used for pooling is trainable + """ + super(HopfieldPooling, self).__init__() + self.hopfield = Hopfield( + input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size, + num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps, + normalize_stored_pattern=normalize_stored_pattern, + normalize_stored_pattern_affine=normalize_stored_pattern_affine, + normalize_state_pattern=normalize_state_pattern, + normalize_state_pattern_affine=normalize_state_pattern_affine, + normalize_pattern_projection=normalize_pattern_projection, + normalize_pattern_projection_affine=normalize_pattern_projection_affine, + normalize_hopfield_space=normalize_hopfield_space, + normalize_hopfield_space_affine=normalize_hopfield_space_affine, + stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static, + pattern_projection_as_static=pattern_projection_as_static, + pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size, + pattern_projection_size=pattern_projection_size, batch_first=batch_first, + association_activation=association_activation, dropout=dropout, input_bias=input_bias, + concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association, + disable_out_projection=disable_out_projection) + self._quantity = quantity + pooling_weight_size = 
self.hopfield.hidden_size if state_pattern_as_static else self.hopfield.input_size + self.pooling_weights = nn.Parameter(torch.empty(size=(*( + (1, quantity) if batch_first else (quantity, 1) + ), input_size if pooling_weight_size is None else pooling_weight_size)), requires_grad=trainable) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Reset pooling weights and underlying Hopfield association. + + :return: None + """ + if hasattr(self.hopfield, r'reset_parameters'): + self.hopfield.reset_parameters() + + # Explicitly initialise pooling weights. + nn.init.normal_(self.pooling_weights, mean=0.0, std=0.02) + + def _prepare_input(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]: + """ + Prepare input for Hopfield association. + + :param input: data to be prepared + :return: stored pattern, expanded state pattern as well as pattern projection + """ + assert (type(input) == Tensor) or ((type(input) == tuple) and (len(input) == 2)), \ + r'either one tensor to be used as "stored pattern" and' \ + r' "pattern_projection" must be provided, or two separate ones.' + if type(input) == Tensor: + stored_pattern, pattern_projection = input, input + else: + stored_pattern, pattern_projection = input + + batch_size = stored_pattern.shape[0 if self.batch_first else 1] + return stored_pattern, self.pooling_weights.expand(size=(*( + (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size) + ), self.pooling_weights.shape[2])), pattern_projection + + def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor]], stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Compute Hopfield-based pooling on specified data. 
+ + :param input: data to be pooled + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-pooled input data + """ + return self.hopfield( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask).flatten(start_dim=1) + + def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield association matrix used for pooling gathered by passing through the specified data. + + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: association matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self.hopfield.get_association_matrix( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask) + + def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield projected pattern matrix gathered by passing through the specified data. 
+ + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: pattern projection matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self.hopfield.get_projected_pattern_matrix( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask) + + @property + def batch_first(self) -> bool: + return self.hopfield.batch_first + + @property + def scaling(self) -> Union[float, Tensor]: + return self.hopfield.scaling + + @property + def stored_pattern_dim(self) -> Optional[int]: + return self.hopfield.stored_pattern_dim + + @property + def state_pattern_dim(self) -> Optional[int]: + return self.hopfield.state_pattern_dim + + @property + def pattern_projection_dim(self) -> Optional[int]: + return self.hopfield.pattern_projection_dim + + @property + def input_size(self) -> Optional[int]: + return self.hopfield.input_size + + @property + def hidden_size(self) -> int: + return self.hopfield.hidden_size + + @property + def output_size(self) -> Optional[int]: + return self.hopfield.output_size + + @property + def pattern_size(self) -> Optional[int]: + return self.hopfield.pattern_size + + @property + def quantity(self) -> int: + return self._quantity + + @property + def update_steps_max(self) -> Optional[Union[int, Tensor]]: + return self.hopfield.update_steps_max + + @property + def update_steps_eps(self) -> Optional[Union[float, Tensor]]: + return self.hopfield.update_steps_eps + + @property + def stored_pattern_as_static(self) -> bool: + return self.hopfield.stored_pattern_as_static + + @property + def state_pattern_as_static(self) -> bool: + return self.hopfield.state_pattern_as_static + + @property + def pattern_projection_as_static(self) -> bool: + return self.hopfield.pattern_projection_as_static + + 
class HopfieldLayer(Module):
    """
    Hopfield-based lookup layer.

    Combines a fixed (optionally trainable) set of stored patterns, an optional separate set of
    lookup targets (used as pattern projection input) and a "Hopfield" association in one module.
    The data passed to "forward" is used as the state pattern.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,
                 lookup_weights_as_separated: bool = False,
                 lookup_targets_as_trainable: bool = True,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,
                 quantity: int = 1,
                 trainable: bool = True
                 ):
        """
        Set up a Hopfield-based lookup layer.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param lookup_weights_as_separated: keep lookup weights separate from lookup target weights
        :param lookup_targets_as_trainable: make the lookup target weights (pattern projection input) trainable
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: additional activation to be applied on the result of the Hopfield association
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        :param quantity: amount of stored patterns
        :param trainable: stored pattern used for lookup is trainable
        """
        super().__init__()

        # Everything except the lookup-specific options is forwarded verbatim to the association module.
        association_options = {
            'input_size': input_size,
            'hidden_size': hidden_size,
            'output_size': output_size,
            'pattern_size': pattern_size,
            'num_heads': num_heads,
            'scaling': scaling,
            'update_steps_max': update_steps_max,
            'update_steps_eps': update_steps_eps,
            'normalize_stored_pattern': normalize_stored_pattern,
            'normalize_stored_pattern_affine': normalize_stored_pattern_affine,
            'normalize_state_pattern': normalize_state_pattern,
            'normalize_state_pattern_affine': normalize_state_pattern_affine,
            'normalize_pattern_projection': normalize_pattern_projection,
            'normalize_pattern_projection_affine': normalize_pattern_projection_affine,
            'normalize_hopfield_space': normalize_hopfield_space,
            'normalize_hopfield_space_affine': normalize_hopfield_space_affine,
            'stored_pattern_as_static': stored_pattern_as_static,
            'state_pattern_as_static': state_pattern_as_static,
            'pattern_projection_as_static': pattern_projection_as_static,
            'pattern_projection_as_connected': pattern_projection_as_connected,
            'stored_pattern_size': stored_pattern_size,
            'pattern_projection_size': pattern_projection_size,
            'batch_first': batch_first,
            'association_activation': association_activation,
            'dropout': dropout,
            'input_bias': input_bias,
            'concat_bias_pattern': concat_bias_pattern,
            'add_zero_association': add_zero_association,
            'disable_out_projection': disable_out_projection,
        }
        self.hopfield = Hopfield(**association_options)
        self._quantity = quantity

        # Static stored patterns bypass the projection, hence they live directly in the association space.
        if stored_pattern_as_static:
            lookup_depth = self.hopfield.hidden_size
        else:
            lookup_depth = self.hopfield.stored_pattern_dim
        if lookup_depth is None:
            lookup_depth = input_size

        # The batch axis is a singleton; it is expanded to the actual batch size in "_prepare_input".
        leading_dims = (1, quantity) if batch_first else (quantity, 1)
        self.lookup_weights = nn.Parameter(
            torch.empty(size=(*leading_dims, lookup_depth)), requires_grad=trainable)

        if not lookup_weights_as_separated:
            # Without separate targets, the lookup weights double as pattern projection input.
            self.register_parameter(name=r'target_weights', param=None)
        else:
            if pattern_projection_size is None:
                target_depth = self.lookup_weights.shape[2]
            else:
                target_depth = pattern_projection_size
            self.target_weights = nn.Parameter(
                torch.empty(size=(*leading_dims, target_depth)), requires_grad=lookup_targets_as_trainable)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Re-initialise the lookup weights, the lookup target weights (if present) and the wrapped association.

        :return: None
        """
        if hasattr(self.hopfield, r'reset_parameters'):
            self.hopfield.reset_parameters()

        # Lookup weights first, then the (optional) targets: both are drawn from N(0, 0.02^2).
        trainable_tables = [self.lookup_weights]
        if self.target_weights is not None:
            trainable_tables.append(self.target_weights)
        for table in trainable_tables:
            nn.init.normal_(table, mean=0.0, std=0.02)

    def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Assemble the pattern triple consumed by the Hopfield association.

        :param input: data to be used as state pattern
        :return: expanded stored pattern, the unmodified input as well as the expanded pattern projection
        """
        if self.batch_first:
            leading_dims = (input.shape[0], self.quantity)
        else:
            leading_dims = (self.quantity, input.shape[1])

        # "-1" keeps the feature depth of the respective weight tensor untouched.
        stored_pattern = self.lookup_weights.expand(*leading_dims, -1)
        if self.target_weights is not None:
            pattern_projection = self.target_weights.expand(*leading_dims, -1)
        else:
            pattern_projection = stored_pattern

        return stored_pattern, input, pattern_projection

    def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Perform the Hopfield-based lookup for the specified data.

        :param input: data to be used in the lookup
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: result of Hopfield-based lookup on input data
        """
        patterns = self._prepare_input(input=input)
        return self.hopfield(
            input=patterns,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)

    def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Compute the association matrix of the lookup for the specified data (gradients disabled).

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            patterns = self._prepare_input(input=input)
            return self.hopfield.get_association_matrix(
                input=patterns,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Compute the projected pattern matrix for the specified data (gradients disabled).

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            patterns = self._prepare_input(input=input)
            return self.hopfield.get_projected_pattern_matrix(
                input=patterns,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    # All properties below, except "quantity", are read-only views on the wrapped association module.

    @property
    def batch_first(self) -> bool:
        """Value of "batch_first" of the wrapped association."""
        return self.hopfield.batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        """Value of "scaling" of the wrapped association."""
        return self.hopfield.scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        """Value of "stored_pattern_dim" of the wrapped association."""
        return self.hopfield.stored_pattern_dim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        """Value of "state_pattern_dim" of the wrapped association."""
        return self.hopfield.state_pattern_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        """Value of "pattern_projection_dim" of the wrapped association."""
        return self.hopfield.pattern_projection_dim

    @property
    def input_size(self) -> Optional[int]:
        """Value of "input_size" of the wrapped association."""
        return self.hopfield.input_size

    @property
    def hidden_size(self) -> int:
        """Value of "hidden_size" of the wrapped association."""
        return self.hopfield.hidden_size

    @property
    def output_size(self) -> Optional[int]:
        """Value of "output_size" of the wrapped association."""
        return self.hopfield.output_size

    @property
    def pattern_size(self) -> Optional[int]:
        """Value of "pattern_size" of the wrapped association."""
        return self.hopfield.pattern_size

    @property
    def quantity(self) -> int:
        """Amount of stored patterns, as specified on construction."""
        return self._quantity

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        """Value of "update_steps_max" of the wrapped association."""
        return self.hopfield.update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        """Value of "update_steps_eps" of the wrapped association."""
        return self.hopfield.update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        """Value of "stored_pattern_as_static" of the wrapped association."""
        return self.hopfield.stored_pattern_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        """Value of "state_pattern_as_static" of the wrapped association."""
        return self.hopfield.state_pattern_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        """Value of "pattern_projection_as_static" of the wrapped association."""
        return self.hopfield.pattern_projection_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        """Value of "normalize_stored_pattern" of the wrapped association."""
        return self.hopfield.normalize_stored_pattern

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        """Value of "normalize_stored_pattern_affine" of the wrapped association."""
        return self.hopfield.normalize_stored_pattern_affine

    @property
    def normalize_state_pattern(self) -> bool:
        """Value of "normalize_state_pattern" of the wrapped association."""
        return self.hopfield.normalize_state_pattern

    @property
    def normalize_state_pattern_affine(self) -> bool:
        """Value of "normalize_state_pattern_affine" of the wrapped association."""
        return self.hopfield.normalize_state_pattern_affine

    @property
    def normalize_pattern_projection(self) -> bool:
        """Value of "normalize_pattern_projection" of the wrapped association."""
        return self.hopfield.normalize_pattern_projection

    @property
    def normalize_pattern_projection_affine(self) -> bool:
        """Value of "normalize_pattern_projection_affine" of the wrapped association."""
        return self.hopfield.normalize_pattern_projection_affine
class HopfieldCore(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See references: "Hopfield Networks is All You Need" and
    "Attention Is All You Need" (on which this implementation is partly based on).

    .. math::
        \text{HopfieldHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        head_dim: dimension of the association space of each head. Default: None
                  (resolved to ``embed_dim // num_heads``).
        pattern_dim: dimension of each projected value input. Default: None
                     (resolved to the head dimension).
        out_dim: dimension of the output projection. Default: None (resolved to ``embed_dim``).
        disable_out_projection: do not create an output projection.
        key_as_static: interpret specified key as being static (no key projection is created).
        query_as_static: interpret specified query as being static (no query projection is created).
        value_as_static: interpret specified value as being static (no value projection is created).
        value_as_connected: connect value projection with key projection.
        normalize_pattern: enable normalization of patterns.
        normalize_pattern_affine: additionally create affine pattern normalization parameters.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::

        >>> hopfield_attn = HopfieldCore(embed_dim, num_heads)
        >>> attn_output, attn_output_weights, attn_matrix = hopfield_attn(query, key, value)
    """
    # typing.Optional is used directly instead of reaching into the private torch._jit_internal module.
    __annotations__ = {
        'bias_k': Optional[torch.Tensor],
        'bias_v': Optional[torch.Tensor],
    }

    def __init__(self,
                 embed_dim=None,                  # type: Optional[int]
                 num_heads=1,                     # type: int
                 dropout=0.0,                     # type: float
                 bias=True,                       # type: bool
                 add_bias_kv=False,               # type: bool
                 add_zero_attn=False,             # type: bool
                 kdim=None,                       # type: Optional[int]
                 vdim=None,                       # type: Optional[int]

                 head_dim=None,                   # type: Optional[int]
                 pattern_dim=None,                # type: Optional[int]
                 out_dim=None,                    # type: Optional[int]
                 disable_out_projection=False,    # type: bool
                 key_as_static=False,             # type: bool
                 query_as_static=False,           # type: bool
                 value_as_static=False,           # type: bool
                 value_as_connected=False,        # type: bool
                 normalize_pattern=False,         # type: bool
                 normalize_pattern_affine=False   # type: bool
                 ):
        """
        Validate the configuration, resolve all derived dimensions and create the parameters.

        See the class documentation for a description of the arguments.
        """
        super(HopfieldCore, self).__init__()

        assert isinstance(key_as_static, bool) and isinstance(query_as_static, bool) and \
            isinstance(value_as_static, bool)
        self.key_as_static, self.query_as_static, self.value_as_static = key_as_static, query_as_static, value_as_static
        num_non_static = 3 - (self.key_as_static + self.query_as_static + self.value_as_static)
        assert 0 <= num_non_static < 4

        self.value_as_connected = value_as_connected
        self.normalize_pattern, self.normalize_pattern_affine = normalize_pattern, normalize_pattern_affine
        self.disable_out_projection = disable_out_projection

        # In case of a static-only executions, check corresponding projections and normalizations.
        self.static_execution = self._check_execution_mode()
        if self.static_execution:
            embed_dim, kdim, vdim = None, None, None
        if embed_dim is None:
            assert self.static_execution, r'static-only execution requires all projections to be deactivated.'

        # Check and set all other properties, conditioned on the execution mode.
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = all((
            self.kdim == embed_dim, self.vdim == embed_dim, pattern_dim is None, not self.value_as_connected))
        assert (not self.value_as_connected) or (self.kdim == self.vdim), r'key and value need to be of same dimension.'

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = None
        self.pattern_dim = pattern_dim
        self.virtual_hopfield_dim = None
        self.virtual_pattern_dim = None
        if not self.static_execution:
            if head_dim is None:
                self.head_dim = embed_dim // num_heads
                assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads."
            else:
                assert head_dim > 0, "dimension of the association space has to be positive."
                self.head_dim = head_dim
            if self.pattern_dim is None:
                self.pattern_dim = self.head_dim
            self.virtual_hopfield_dim = self.num_heads * self.head_dim
            self.virtual_pattern_dim = self.num_heads * self.pattern_dim

        self.out_dim = embed_dim if out_dim is None else out_dim
        assert disable_out_projection or (self.out_dim > 0), "output projection dimension has to be positive."

        if normalize_pattern_affine:
            assert normalize_pattern, "affine pattern normalization without pattern normalization has no effect."
            # Size by the *resolved* head dimension: the raw "head_dim" argument is None by default, which
            # previously made the construction fail whenever affine normalization was requested without an
            # explicit head dimension. A static-only execution never reaches this branch, as it requires
            # affine normalization to be disabled, hence "self.head_dim" is always set here.
            self.p_norm_weight = Parameter(torch.empty(self.head_dim))
            self.p_norm_bias = Parameter(torch.empty(self.head_dim))
        else:
            self.register_parameter('p_norm_weight', None)
            self.register_parameter('p_norm_bias', None)

        if not self._qkv_same_embed_dim:
            # Separate projection weights, one per non-static input.
            if query_as_static:
                self.register_parameter('q_proj_weight', None)
            else:
                self.q_proj_weight = Parameter(torch.empty(self.virtual_hopfield_dim, embed_dim))
            if key_as_static:
                self.register_parameter('k_proj_weight', None)
            else:
                self.k_proj_weight = Parameter(torch.empty(self.virtual_hopfield_dim, self.kdim))
            if value_as_static:
                self.register_parameter('v_proj_weight', None)
            else:
                # A connected value projection operates on the projected key, if the key is projected at all.
                self.v_proj_weight = Parameter(torch.empty(
                    self.virtual_pattern_dim,
                    self.virtual_hopfield_dim if (value_as_connected and not key_as_static) else self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            # One packed projection weight, holding the rows of all non-static inputs.
            if num_non_static > 0:
                self.in_proj_weight = Parameter(torch.empty(
                    (not query_as_static) * self.virtual_hopfield_dim +
                    (not key_as_static) * self.virtual_hopfield_dim +
                    (not value_as_static) * self.virtual_pattern_dim, embed_dim))
            else:
                self.register_parameter('in_proj_weight', None)
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias and (num_non_static > 0):
            # NOTE(review): unlike "in_proj_weight", the value slice is not gated on "value_as_static".
            # The shape is kept as is, in order to stay compatible with existing checkpoints.
            self.in_proj_bias = Parameter(torch.empty(
                (not query_as_static) * self.virtual_hopfield_dim +
                (not key_as_static) * self.virtual_hopfield_dim + self.virtual_pattern_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        if disable_out_projection:
            self.register_parameter('out_proj', None)
        else:
            # "_LinearWithBias" only exists in some PyTorch versions (None if the import failed).
            if bias and _LinearWithBias is not None:
                self.out_proj = _LinearWithBias(self.virtual_pattern_dim, self.out_dim)
            else:
                self.out_proj = Linear(self.virtual_pattern_dim, self.out_dim, bias=bias)

        self.bias_k, self.bias_v = None, None
        if add_bias_kv:
            if not key_as_static:
                self.bias_k = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            if not value_as_static:
                # NOTE(review): sized like the key bias (virtual_hopfield_dim), not virtual_pattern_dim;
                # kept unchanged -- confirm against the functional implementation before altering.
                self.bias_v = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            assert not (self.bias_k is None and self.bias_v is None), r'cannot set key/value bias if both are static.'

        self.add_zero_attn = add_zero_attn
        self.reset_parameters()

    def _check_execution_mode(self) -> bool:
        """
        Check for a static-only execution.

        :return: True, if all inputs are static and no connection, normalization or output projection is active
        """
        return all((
            self.key_as_static, self.query_as_static, self.value_as_static, not self.value_as_connected,
            not self.normalize_pattern, not self.normalize_pattern_affine, self.disable_out_projection
        ))

    def reset_parameters(self):
        """
        Re-initialise all existing parameters: normalization weight/bias with ones/zeros, projection weights
        as well as key/value biases with N(0, 0.02^2) and projection biases with zeros.

        :return: None
        """
        if self.p_norm_weight is not None:
            nn.init.ones_(self.p_norm_weight)
            nn.init.zeros_(self.p_norm_bias)

        if self._qkv_same_embed_dim and (self.in_proj_weight is not None):
            nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02)
        else:
            if self.q_proj_weight is not None:
                nn.init.normal_(self.q_proj_weight, mean=0.0, std=0.02)
            if self.k_proj_weight is not None:
                nn.init.normal_(self.k_proj_weight, mean=0.0, std=0.02)
            if self.v_proj_weight is not None:
                nn.init.normal_(self.v_proj_weight, mean=0.0, std=0.02)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
        if not self.disable_out_projection:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.02)
            if self.out_proj.bias is not None:
                nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.normal_(self.bias_k, mean=0.0, std=0.02)
        if self.bias_v is not None:
            nn.init.normal_(self.bias_v, mean=0.0, std=0.02)

    def __setstate__(self, state):
        # Kept as an explicit pass-through, in order to leave the class interface unchanged.
        super(HopfieldCore, self).__setstate__(state)

    def forward(self,
                query,                            # type: Tensor
                key,                              # type: Tensor
                value,                            # type: Tensor
                key_padding_mask=None,            # type: Optional[Tensor]
                need_weights=True,                # type: bool
                attn_mask=None,                   # type: Optional[Tensor]

                scaling=None,                     # type: Optional[Tensor]
                update_steps_max=0,               # type: Optional[int]
                update_steps_eps=1e-4,            # type: float
                return_raw_associations=False,    # type: bool
                return_pattern_projections=False  # type: bool
                ):
        # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
        r"""
        Validate the feature depth of query, key and value and delegate to "hopfield_core_forward".

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
                See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

            scaling: scaling of association heads, often represented as beta (one entry per head).
            update_steps_max: maximum count of association update steps (None equals to infinity).
            update_steps_eps: minimum difference threshold between two consecutive association update steps.
            return_raw_associations: return raw association (softmax) values, unmodified.
            return_pattern_projections: return pattern projection values, unmodified.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
            - attn_raw: :math:`(N, num_heads, L, S)`, where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if self.query_as_static and self.key_as_static:
            # Neither query nor key is projected: the association space is defined by the data itself.
            assert query.shape[2] == key.shape[2], \
                f'query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal'
            head_dim, embed_dim_to_check = query.shape[2], query.shape[2]
        else:
            # A projected input needs to match its projection, a static one the association space.
            assert self.query_as_static or (query.shape[2] == self.embed_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}.'
            assert (not self.query_as_static) or (query.shape[2] == self.head_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}'

            assert self.key_as_static or (key.shape[2] == self.kdim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}.'
            assert (not self.key_as_static) or (key.shape[2] == self.head_dim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}'
            head_dim, embed_dim_to_check = self.head_dim, self.head_dim if self.query_as_static else self.embed_dim

        assert self.value_as_static or (value.shape[2] == self.vdim), \
            f'value shape[2] of {value.shape[2]} invalid, needs to be {self.vdim}.'
        assert any((
            not self.value_as_static, self.value_as_static and value.shape[2] == self.pattern_dim,
            self.disable_out_projection)
        ), f'value shape[2] of {value.shape[2]} invalid, needs to be {self.pattern_dim}'

        out_weights, out_bias = None, None
        if not self.disable_out_projection:
            out_weights, out_bias = self.out_proj.weight, self.out_proj.bias

        if not self._qkv_same_embed_dim:
            # Separate projection weights are in use, hence they need to be handed over explicitly.
            return hopfield_core_forward(
                query=query, key=key, value=value, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
                in_proj_weight=self.in_proj_weight, in_proj_bias=self.in_proj_bias, bias_k=self.bias_k,
                bias_v=self.bias_v, add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
                out_proj_weight=out_weights, out_proj_bias=out_bias, training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,

                key_as_static=self.key_as_static, query_as_static=self.query_as_static,
                value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
                normalize_pattern=self.normalize_pattern,
                p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
                head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
                update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
                return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
        else:
            return hopfield_core_forward(
                query=query, key=key, value=value, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
                in_proj_weight=self.in_proj_weight, in_proj_bias=self.in_proj_bias, bias_k=self.bias_k,
                bias_v=self.bias_v, add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
                out_proj_weight=out_weights, out_proj_bias=out_bias, training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,

                key_as_static=self.key_as_static, query_as_static=self.query_as_static,
                value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
                normalize_pattern=self.normalize_pattern,
                p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
                head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
                update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
                return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None, # type: Optional[Tensor] + + key_as_static=False, # type: bool + query_as_static=False, # type: bool + value_as_static=False, # type: bool + value_as_connected=False, # type: bool + normalize_pattern=False, # type: bool + p_norm_weight=None, # type: Optional[Tensor] + p_norm_bias=None, # type: Optional[Tensor] + head_dim=None, # type: Optional[int] + pattern_dim=None, # type: Optional[int] + scaling=None, # type: Optional[Union[float, Tensor]] + update_steps_max=0, # type: Optional[Union[int, Tensor]] + update_steps_eps=1e-4, # type: Union[float, Tensor] + return_raw_associations=False, # type: bool + return_projected_patterns=False # type: bool + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks. + embed_dim_to_check: total dimension of the model (in case of default head dimension). + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. 
+ out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + key_as_static: interpret specified key as being static. + query_as_static: interpret specified key as being static. + value_as_static: interpret specified key as being static. + value_as_connected: connect value projection with key projection. + normalize_pattern: enable normalization of patterns. + p_norm_weight, p_norm_bias: pattern normalization weight and bias. + head_dim: dimensionality of each head. + pattern_dim: dimensionality of each projected value input. + scaling: scaling of association heads, often represented as beta (one entry per head). + update_steps_max: maximum count of association update steps (None equals to infinity). + update_steps_eps: minimum difference threshold between two consecutive association update steps. + return_raw_associations: return raw association (softmax) values, unmodified. + return_projected_patterns: return pattern projection values, unmodified. 
+ + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, head_dim)`, where S is the source sequence length, N is the batch size. + - static_v: :math:`(N*num_heads, S, head_dim)`, where S is the source sequence length, N is the batch size. + + - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads. + + Outputs: + - attn_output: :math:`(L, N, E)`, where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)`, where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ - attn_raw: :math:``(N, num_heads, L, S)`, where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not torch.jit.is_scripting(): + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + out_proj_weight, out_proj_bias) + if any([type(t) is not Tensor for t in tens_ops]) and nn.functional.has_torch_function(tens_ops): + return nn.functional.handle_torch_function( + hopfield_core_forward, tens_ops, query, key, value, + embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + out_proj_bias, training=training, key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v, + key_as_static=key_as_static, query_as_static=query_as_static, + value_as_static=value_as_static, value_as_connected=value_as_connected, + normalize_pattern=normalize_pattern, p_norm_weight=p_norm_weight, p_norm_bias=p_norm_bias, + head_dim=head_dim, pattern_dim=pattern_dim, scaling=scaling, update_steps_max=update_steps_max, + update_steps_eps=update_steps_eps, return_raw_associations=return_raw_associations) + tgt_len, bsz, embed_dim = query.shape[0], value.shape[1], query.shape[2] + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + assert (scaling is None) or (type(scaling) in (float, torch.Tensor)) + if type(scaling) == torch.Tensor: + assert scaling.ndimension() == 1 and scaling.shape[0] == num_heads, "only one entry per head." 
+ + assert (update_steps_max is None) or (type(update_steps_max) in (int, torch.Tensor)) + if type(update_steps_max) == torch.Tensor: + assert update_steps_max.ndimension() == 1 and update_steps_max.shape[0] == num_heads, "only one entry per head." + elif type(update_steps_max) == int: + update_steps_max = torch.tensor([update_steps_max] * num_heads, dtype=torch.int32, device=query.device) + elif update_steps_max is None: + update_steps_max = -torch.ones(size=(num_heads,), dtype=torch.int32, device=query.device) + + assert type(update_steps_eps) in (float, torch.Tensor) + if type(update_steps_eps) == torch.Tensor: + assert update_steps_eps.ndimension() == 1 and update_steps_eps.shape[0] == num_heads, "only one entry per head." + assert (update_steps_eps <= 0.0).sum() == 0, "only positive thresholds allowed." + update_steps_eps = update_steps_eps.to(device=query.device) + elif type(update_steps_eps) == float: + assert update_steps_eps > 0, "only positive thresholds allowed." + update_steps_eps = torch.tensor([update_steps_eps] * num_heads, dtype=query.dtype, device=query.device) + + # Adapt dimensionality of each each. + if head_dim is None: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, r'embed_dim must be divisible by num_heads.' + hopfield_dim = num_heads * head_dim + + # Adapt dimensionality of each value projection. + if pattern_dim is None: + pattern_dim = head_dim + assert (not value_as_connected) or (pattern_dim == head_dim) + + q, k, v, xi, src_len = None, None, None, None, 0 + update_step, xi_old, xi_difference_norm = 0, None, float(r'+inf') + update_active_heads = torch.tensor([[[True]]] * num_heads * bsz, device=query.device) + assert update_active_heads.any(), "at least one head needs to be active." 
+ + #################################################################################################################### + # BEGIN HOPFIELD UPDATE ITERATION # + #################################################################################################################### + + while update_active_heads.any(): + + # The query is already projected into the "Hopfield" space at "update_step" equals 0. + # No more projection necessary if "update_step" greater than 0. + if update_step == 0: + if not use_separate_proj_weight: + + if torch.equal(query, key) and torch.equal(key, value) and not ( + key_as_static or query_as_static or value_as_static): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value) and not (key_as_static or value_as_static): + # encoder-decoder attention + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + _start = hopfield_dim + _end = None + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + _start += hopfield_dim + _end += hopfield_dim + + if key_as_static: + k = key.repeat(1, num_heads, 1) + else: + # This is inline in_proj 
function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + _start += hopfield_dim + _end += hopfield_dim + + if value_as_static: + v = value.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + v = nn.functional.linear(value, _w, _b) + else: + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == hopfield_dim and len2 == query.size(-1) + if in_proj_bias is not None: + q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias[_start:_end]) + _start += hopfield_dim + _end += hopfield_dim + else: + q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias) + + v = value + if key_as_static: + k = key.repeat(1, num_heads, 1) + else: + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == hopfield_dim and len2 == key.size(-1) + + _bias = None if in_proj_bias is None else in_proj_bias[_start:_end] + k = nn.functional.linear(key, k_proj_weight_non_opt, _bias) + if value_as_connected: + v = nn.functional.linear(v, k_proj_weight_non_opt, _bias) + _start += hopfield_dim + _end += num_heads * pattern_dim + + if value_as_static: + if not (value_as_connected or key_as_static): + v = v.repeat(1, num_heads, 1) + else: + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == (num_heads * pattern_dim) and len2 == v.size(-1) + if in_proj_bias is not None: + v = nn.functional.linear(v, v_proj_weight_non_opt, in_proj_bias[_start:]) + else: + v = nn.functional.linear(v, 
v_proj_weight_non_opt, in_proj_bias) + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or \ + attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # Optionally normalize patterns. + if normalize_pattern: + q = torch.nn.functional.layer_norm( + input=q.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,), + weight=p_norm_weight, bias=p_norm_bias).reshape(shape=q.shape) + k = torch.nn.functional.layer_norm( + input=k.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,), + weight=p_norm_weight, bias=p_norm_bias).reshape(shape=k.shape) + + else: + active_xi = xi.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:])) + active_k = k.masked_select(mask=update_active_heads).view(size=(-1, *k.shape[1:])) + q = torch.masked_scatter(input=q, mask=update_active_heads, source=torch.bmm(active_xi, active_k)) + + # Optionally scale association heads (each head separately). 
+ if type(scaling) == float: + q = q * scaling + elif type(scaling) == torch.Tensor: + q = q * scaling.view(1, 1, -1).repeat(repeats=(1, 1, q.shape[2] // scaling.shape[0])) + + if update_step == 0: + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None and key_as_static is None and value_as_static is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = nn.functional.pad(attn_mask, [0, 1]) + if key_padding_mask is not None: + key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1]) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + assert not key_as_static, "bias cannot be added to static key." + assert not value_as_static, "bias cannot be added to static value." 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, -1, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(v.shape[0], bsz * num_heads, -1).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == pattern_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = nn.functional.pad(attn_mask, [0, 1]) + if key_padding_mask is not None: + key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1]) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + # Compute new xi for Hopfield retrieve iterations. 
+ if xi is None: + xi = nn.functional.softmax(attn_output_weights, dim=-1) + else: + xi = torch.masked_scatter(input=xi, mask=update_active_heads, source=nn.functional.softmax( + attn_output_weights.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:])), dim=-1)) + + # Compute threshold-based stopping criterion for Hopfield retrieve iterations. + with torch.no_grad(): + xi_active = xi.view(size=(bsz, num_heads, tgt_len, src_len)) + update_active_heads = (update_step < update_steps_max) | (update_steps_max < 0) + if xi_old is not None: + update_active_heads &= ((xi_old - xi_active).norm(p=2, dim=(2, 3)).max(axis=0)[0]) > update_steps_eps + update_active_heads = update_active_heads.unsqueeze(dim=1).unsqueeze(dim=2).repeat(repeats=(bsz, 1, 1)) + xi_old = xi_active + update_step += 1 + + #################################################################################################################### + # END HOPFIELD UPDATE ITERATION # + #################################################################################################################### + + attn_output_weights = nn.functional.dropout(xi, p=dropout_p, training=training) + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.shape[:2]) == [bsz * num_heads, tgt_len] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1) + if out_proj_weight is not None: + assert attn_output.shape[2] == num_heads * pattern_dim + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + xi = xi.view(bsz, num_heads, tgt_len, src_len) if return_raw_associations else None + v = v.view(bsz, num_heads, src_len, -1) if return_projected_patterns else None + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads, xi, v + else: + return attn_output, None, xi, v diff --git 
class HopfieldEncoderLayer(Module):
    """
    Module with underlying Hopfield association to be used as an encoder in transformer-like architectures.

    The layer applies the Hopfield association followed by a two-layer feed-forward block; each of the
    two sub-blocks is wrapped in dropout, a residual connection and a subsequent layer normalization.
    """

    def __init__(self,
                 hopfield_association: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Initialise a new instance of a Hopfield association-based encoder module.

        :param hopfield_association: instance of Hopfield association module (a deep copy is stored)
        :param dim_feedforward: depth of the linear projections applied internally
        :param dropout: dropout probability to be applied internally
        :param activation: name of the function in the "torch" namespace to be applied on the result
            of the internal linear projections
        """
        super().__init__()

        # The association is copied, so the instance passed in is never shared with the caller.
        self.hopfield_association = deepcopy(hopfield_association)
        pattern_dim = self.hopfield_association.state_pattern_dim

        self.linear_residual = nn.Linear(pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, pattern_dim)

        self.norm_residual = nn.LayerNorm(pattern_dim)
        self.norm_output = nn.LayerNorm(pattern_dim)
        self.dropout_hopfield_association = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is resolved by name, e.g. "relu" resolves to "torch.relu".
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset parameters, including Hopfield association.

        :return: None
        """
        resettable_modules = (self.hopfield_association, self.linear_residual,
                              self.linear_output, self.norm_residual, self.norm_output)
        for module in resettable_modules:
            # Modules without a "reset_parameters" method are silently skipped.
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield encoding on specified data.

        :param src: data to be processed by Hopfield encoder module
        :param src_mask: mask to be applied on association matrix
        :param src_key_padding_mask: mask to be applied on stored patterns
        :return: Hopfield-encoded input data
        """
        # Association sub-block: dropout -> residual connection -> layer normalization.
        data_associated = self.hopfield_association(
            input=src, stored_pattern_padding_mask=src_key_padding_mask, association_mask=src_mask)
        src = src + self.dropout_hopfield_association(input=data_associated)
        src = self.norm_residual(input=src)

        # Feed-forward sub-block: linear -> activation -> dropout -> linear, then residual and normalization.
        result_residual_inner = self.activation_residual(input=self.linear_residual(input=src))
        data_associated = self.linear_output(input=self.dropout_residual(input=result_residual_inner))
        src = src + self.dropout_output(input=data_associated)

        return self.norm_output(input=src)

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association.get_association_matrix(input=input)

    @property
    def batch_first(self) -> int:
        # Forwarded unchanged from the wrapped Hopfield association.
        return self.hopfield_association.batch_first

    @property
    def input_size(self) -> int:
        # Forwarded unchanged from the wrapped Hopfield association.
        return self.hopfield_association.input_size

    @property
    def output_size(self) -> int:
        # Width of the last linear projection of the feed-forward sub-block.
        return self.linear_output.out_features


class HopfieldDecoderLayer(Module):
    """
    Module with underlying Hopfield associations to be used as a decoder in transformer-like architectures.

    The layer applies a self-association, a cross-association and a two-layer feed-forward block; each
    of the three sub-blocks is wrapped in dropout, a residual connection and a subsequent layer normalization.
    """

    def __init__(self,
                 hopfield_association_self: Hopfield,
                 hopfield_association_cross: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Initialise a new instance of a Hopfield association-based decoder module.

        :param hopfield_association_self: instance of Hopfield self-association module (a deep copy is stored)
        :param hopfield_association_cross: instance of Hopfield cross-association module (a deep copy is stored)
        :param dim_feedforward: depth of the linear projections applied internally
        :param dropout: dropout probability to be applied internally
        :param activation: name of the function in the "torch" namespace to be applied on the result
            of the internal linear projections
        """
        super().__init__()

        # Both associations are copied, so the instances passed in are never shared with the caller.
        self.hopfield_association_self = deepcopy(hopfield_association_self)
        self.hopfield_association_cross = deepcopy(hopfield_association_cross)
        # All internal layers are sized by the state pattern dimension of the self-association.
        pattern_dim = self.hopfield_association_self.state_pattern_dim

        self.linear_residual = nn.Linear(pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, pattern_dim)

        self.norm_residual_self = nn.LayerNorm(pattern_dim)
        self.norm_residual_cross = nn.LayerNorm(pattern_dim)
        self.norm_output = nn.LayerNorm(pattern_dim)
        self.dropout_hopfield_association_self = nn.Dropout(dropout)
        self.dropout_hopfield_association_cross = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is resolved by name, e.g. "relu" resolves to "torch.relu".
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset parameters, including Hopfield association.

        :return: None
        """
        resettable_modules = (self.hopfield_association_self, self.hopfield_association_cross,
                              self.linear_residual, self.linear_output, self.norm_residual_self,
                              self.norm_residual_cross, self.norm_output)
        for module in resettable_modules:
            # Modules without a "reset_parameters" method are silently skipped.
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield decoding on specified data.

        :param tgt: data to be processed by Hopfield decoder module (self-association)
        :param memory: data to be processed by Hopfield encoder module (cross-association)
        :param tgt_mask: mask to be applied on self-association matrix
        :param memory_mask: mask to be applied on cross-association matrix
        :param tgt_key_padding_mask: mask to be applied on stored patterns
        :param memory_key_padding_mask: mask to be applied on state patterns as well as pattern projection
        :return: Hopfield-decoded input
        """
        # Self-association sub-block: dropout -> residual connection -> layer normalization.
        data_associated = self.hopfield_association_self(
            input=tgt, stored_pattern_padding_mask=tgt_key_padding_mask,
            association_mask=tgt_mask)
        tgt = tgt + self.dropout_hopfield_association_self(input=data_associated)
        tgt = self.norm_residual_self(input=tgt)

        # Cross-association sub-block: "memory" is passed as first and third tuple entry, "tgt" as second.
        data_associated = self.hopfield_association_cross(
            input=(memory, tgt, memory), stored_pattern_padding_mask=memory_key_padding_mask,
            association_mask=memory_mask)
        tgt = tgt + self.dropout_hopfield_association_cross(input=data_associated)
        tgt = self.norm_residual_cross(input=tgt)

        # Feed-forward sub-block: linear -> activation -> dropout -> linear, then residual and normalization.
        result_residual_inner = self.activation_residual(input=self.linear_residual(input=tgt))
        data_associated = self.linear_output(input=self.dropout_residual(input=result_residual_inner))
        tgt = tgt + self.dropout_output(input=data_associated)
        return self.norm_output(input=tgt)

    def get_association_matrix_self(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield self-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_self.get_association_matrix(input=input)

    def get_association_matrix_cross(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield cross-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_cross.get_association_matrix(input=input)

    @property
    def batch_first(self) -> int:
        # Forwarded unchanged from the wrapped Hopfield self-association.
        return self.hopfield_association_self.batch_first

    @property
    def input_size(self) -> int:
        # Forwarded unchanged from the wrapped Hopfield self-association.
        return self.hopfield_association_self.input_size

    @property
    def output_size(self) -> int:
        # The feed-forward output projection is stored as "linear_output"; an attribute named
        # "linear_output_self" is never created by "__init__".
        return self.linear_output.out_features
index 0000000000000000000000000000000000000000..bfb48f53d397db6a964c29c6d98868bf413ceba5 Binary files /dev/null and b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_base.pdf differ diff --git a/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_lookup.pdf b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_lookup.pdf new file mode 100644 index 0000000000000000000000000000000000000000..04f508591e7386d44cfdb211a9c69b056ed64ad3 Binary files /dev/null and b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_lookup.pdf differ diff --git a/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_lookup_adapted.pdf b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_lookup_adapted.pdf new file mode 100644 index 0000000000000000000000000000000000000000..6bf00e7255e2fe19981f78692a7ca43b1b55b3ad Binary files /dev/null and b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_lookup_adapted.pdf differ diff --git a/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_pooling.pdf b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_pooling.pdf new file mode 100644 index 0000000000000000000000000000000000000000..176f73ac652d4b1fa7620dee9320a842afceb22e Binary files /dev/null and b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_pooling.pdf differ diff --git a/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_pooling_adapted.pdf b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_pooling_adapted.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e58252a1cafc2cdd695675cf94fd31890b4b88e5 Binary files /dev/null and b/src/mhnfs/hopfield/examples/bit_pattern/resources/hopfield_pooling_adapted.pdf differ diff --git a/src/mhnfs/hopfield/examples/latch_sequence/auxiliary/__init__.py b/src/mhnfs/hopfield/examples/latch_sequence/auxiliary/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
class BitPatternSet(Dataset):
    """
    Binary multiple instance learning (MIL) data set comprising bit patterns as instances,
    with implanted bit patterns unique to one of the classes.
    """

    def __init__(self, num_bags: int, num_instances: int, num_signals: int, num_signals_per_bag: int = 1,
                 fraction_targets: float = 0.5, num_bits: int = 8, dtype: torch.dtype = torch.float32,
                 seed_signals: int = 43, seed_data: int = 44):
        """
        Create new binary bit pattern data set conforming to the specified properties.

        :param num_bags: amount of bags
        :param num_instances: amount of instances per bag
        :param num_signals: amount of unique signals used to distinguish both classes
        :param num_signals_per_bag: amount of unique signals to be implanted per bag
        :param fraction_targets: fraction of targets
        :param num_bits: amount of bits per instance
        :param dtype: data type of instances
        :param seed_signals: random seed used to generate the signals of the data set (excl. samples)
        :param seed_data: random seed used to generate the samples of the data set (excl. signals)
        """
        super().__init__()
        # Exact type checks are intentional: e.g. "bool" is rejected although it subclasses "int".
        assert (type(num_bags) is int) and (num_bags > 0), r'"num_bags" must be a positive integer!'
        assert (type(num_instances) is int) and (num_instances > 0), r'"num_instances" must be a positive integer!'
        assert (type(num_signals) is int) and (num_signals > 0), r'"num_signals" must be a positive integer!'
        assert (type(num_signals_per_bag) is int) and (
            0 <= num_signals_per_bag <= num_instances), r'"num_signals_per_bag" must be a non-negative integer!'
        assert (type(fraction_targets) is float) and (
            0 < fraction_targets < 1), r'"fraction_targets" must be in interval (0; 1)!'
        assert (type(num_bits) is int) and (num_bits > 0), r'"num_bits" must be a positive integer!'
        assert ((2 ** num_bits) - 1) > num_signals, r'"num_signals" must be smaller than "2 ** num_bits - 1"!'
        assert type(seed_signals) is int, r'"seed_signals" must be an integer!'
        assert type(seed_data) is int, r'"seed_data" must be an integer!'

        self.__num_bags = num_bags
        self.__num_instances = num_instances
        self.__num_signals = num_signals
        self.__num_signals_per_bag = num_signals_per_bag
        self.__fraction_targets = fraction_targets
        # At least one and at most "num_bags" bags are marked as targets.
        self.__num_targets = min(self.__num_bags, max(1, ceil(self.__num_bags * self.__fraction_targets)))
        self.__num_bits = num_bits
        self.__dtype = dtype
        self.__seed_signals = seed_signals
        self.__seed_data = seed_data
        self.__data, self.__targets, self.__signals = self._generate_bit_pattern_set()

    def __len__(self) -> int:
        """
        Fetch amount of bags.

        :return: amount of bags
        """
        return self.__num_bags

    def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]:
        """
        Fetch specific bag.

        :param item_index: specific bag to fetch
        :return: specific bag as dictionary of tensors (keys "data" and "target", both cast to the data set dtype)
        """
        return {r'data': self.__data[item_index].to(dtype=self.__dtype),
                r'target': self.__targets[item_index].to(dtype=self.__dtype)}

    def _generate_bit_pattern_set(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate underlying bit pattern data set.

        The global torch random generator is re-seeded twice: with the signal seed before the signals
        are drawn, and with the data seed before the bags are drawn and the signals are implanted.

        :return: tuple containing generated bags, targets and signals
        """
        torch.random.manual_seed(seed=self.__seed_signals)

        # Generate signal patterns: re-draw until enough unique, non-zero patterns are available.
        generated_signals = torch.randint(low=0, high=2, size=(self.__num_signals, self.__num_bits))
        while True:
            generated_signals = torch.unique(input=generated_signals, dim=0)
            generated_signals = generated_signals[generated_signals.sum(dim=1) != 0]
            missing_signals = self.__num_signals - generated_signals.shape[0]
            if missing_signals <= 0:
                break
            generated_signals = torch.cat(tensors=(
                generated_signals, torch.randint(low=0, high=2, size=(missing_signals, self.__num_bits))), dim=0)

        # Generate data and target tensors; the first "num_targets" bags form the positive class.
        torch.random.manual_seed(seed=self.__seed_data)
        generated_data = torch.randint(low=0, high=2, size=(self.__num_bags, self.__num_instances, self.__num_bits))
        generated_targets = torch.zeros(size=(self.__num_bags,), dtype=generated_data.dtype)
        generated_targets[:self.__num_targets] = 1

        # Check invalid (all-zero and signal) instances and re-sample them until none is left.
        while True:
            is_zero_instance = generated_data.sum(dim=2) == 0
            is_signal_instance = torch.stack(
                [(generated_data == signal).all(dim=2) for signal in generated_signals]).any(dim=0)
            invalid_instances = is_zero_instance | is_signal_instance
            num_invalid = int(invalid_instances.sum())
            if num_invalid == 0:
                break
            generated_data[invalid_instances] = torch.randint(low=0, high=2, size=(num_invalid, self.__num_bits))

        # Re-implant signal into respective bags (each signal at a distinct position within its bag).
        for data_index in range(self.__num_targets):
            implantation_indices = set()
            for _ in range(self.__num_signals_per_bag):
                while True:
                    current_implantation_index = int(
                        torch.randint(low=0, high=generated_data.shape[1], size=(1,)))
                    if current_implantation_index not in implantation_indices:
                        implantation_indices.add(current_implantation_index)
                        break
                current_signal_index = int(torch.randint(low=0, high=generated_signals.shape[0], size=(1,)))
                generated_data[data_index, current_implantation_index] = generated_signals[current_signal_index]

        return generated_data, generated_targets, generated_signals

    @property
    def num_bags(self) -> int:
        return self.__num_bags

    @property
    def num_instances(self) -> int:
        return self.__num_instances

    @property
    def num_bits(self) -> int:
        return self.__num_bits

    @property
    def num_targets_high(self) -> int:
        # Amount of bags belonging to the positive class.
        return self.__num_targets

    @property
    def num_targets_low(self) -> int:
        # Amount of bags belonging to the negative class.
        return self.__num_bags - self.__num_targets

    @property
    def num_signals(self) -> int:
        return self.__num_signals

    @property
    def num_signals_per_bag(self) -> int:
        return self.__num_signals_per_bag

    @property
    def initial_seed(self) -> int:
        # This data set stores two seeds and no single "seed" attribute. The seed applied first
        # during generation (the signal seed) is reported here; use "seed_signals" and "seed_data"
        # to query either seed explicitly.
        return self.__seed_signals

    @property
    def seed_signals(self) -> int:
        # Random seed used to generate the signals.
        return self.__seed_signals

    @property
    def seed_data(self) -> int:
        # Random seed used to generate the samples.
        return self.__seed_data

    @property
    def bags(self) -> torch.Tensor:
        # A copy is returned, so callers cannot modify the stored bags.
        return self.__data.clone()

    @property
    def targets(self) -> torch.Tensor:
        # A copy is returned, so callers cannot modify the stored targets.
        return self.__targets.clone()

    @property
    def signals(self) -> torch.Tensor:
        # A copy is returned, so callers cannot modify the stored signals.
        return self.__signals.clone()
+ + :param num_samples: amount of samples + :param num_instances: amount of instances per sample + :param num_characters: amount of different characters + :param dtype: data type of samples + :param seed: random seed used to generate the samples of the data set + """ + super(LatchSequenceSet, self).__init__() + assert (type(num_samples) == int) and (num_samples > 0), r'"num_samples" must be a positive integer!' + assert (type(num_instances) == int) and (num_instances > 0), r'"num_instances" must be a positive integer!' + assert (type(num_characters) == int) and (num_characters > 0), r'"num_characters" must be a positive integer!' + assert type(seed) == int, r'"seed" must be an integer!' + + self.__num_samples = num_samples + self.__num_instances = num_instances + self.__num_characters = num_characters + self.__dtype = dtype + self.__seed = seed + self.__data, self.__targets = self._generate_latch_sequences() + + def __len__(self) -> int: + """ + Fetch amount of samples. + + :return: amount of samples + """ + return self.__num_samples + + def __getitem__(self, item_index: int) -> Dict[str, torch.Tensor]: + """ + Fetch specific sample. + + :param item_index: specific sample to fetch + :return: specific sample as dictionary of tensors + """ + return {r'data': self.__data[item_index].to(dtype=self.__dtype), + r'target': self.__targets[item_index].to(dtype=self.__dtype)} + + def _generate_latch_sequences(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate underlying latch sequence data set. + + :return: tuple containing generated data and targets + """ + torch.random.manual_seed(seed=self.__seed) + + # Generate data and target tensors. 
+ generated_data = torch.randint( + low=2, high=self.__num_characters, size=(self.__num_samples, self.__num_instances)) + generated_signal = torch.randint(low=0, high=2, size=(self.__num_samples,)) + generated_data[:, 0] = generated_signal + generated_data = torch.nn.functional.one_hot(input=generated_data, num_classes=self.__num_characters) + + return generated_data, generated_signal + + @property + def num_samples(self) -> int: + return self.__num_samples + + @property + def num_instances(self) -> int: + return self.__num_instances + + @property + def num_characters(self) -> int: + return self.__num_characters + + @property + def initial_seed(self) -> int: + return self.__seed + + @property + def targets(self) -> torch.Tensor: + return self.__targets.clone() diff --git a/src/mhnfs/hopfield/examples/latch_sequence/latch_sequence_demo.ipynb b/src/mhnfs/hopfield/examples/latch_sequence/latch_sequence_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..0d01559cd2c1af8c7f49dce43e93de42623b4a49 --- /dev/null +++ b/src/mhnfs/hopfield/examples/latch_sequence/latch_sequence_demo.ipynb @@ -0,0 +1,1047 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Example: Latch Problem

\n", + "\n", + "We study an easy example of learning long-term dependencies by using a simple latch task (see [Hochreiter and Mozer](https://link.springer.com/chapter/10.1007/3-540-44668-0_92)). The essence of this task is that a sequence of inputs is presented, beginning with one of two symbols, A or B, and after a variable number of time steps, the model has to output a corresponding symbol. Thus, the task requires memorizing the original input over time. It has to be noted, that both class-defining symbols must only appear at the first position of a sequence. This task was specifically designed to demonstrate the capability of recurrent neural networks to capture long term dependencies. This demonstration shows, that Hopfield, HopfieldPooling and HopfieldLayer adapt extremely fast to this specific task, concentrating only on the first entry of the sequence.\n", + "\n", + "This demonstration instructs how to apply Hopfield, HopfieldPooling and HopfieldLayer for an exemplary sequential task, potentially substituting LSTM and GRU layers.\n", + "\n", + "NOTA BENE: No tweaking of the exemplary LSTM network is done. The focus is put on the technical details. Feel free to tune yourself and see what works better :)\n", + "\n", + "

In the chapters Adapt Hopfield-based Network, Adapt Hopfield-based Pooling and Adapt Hopfield-based Lookup you can explore and try the new functionalities of our new Hopfield layer.

\n", + "\n", + "In order to run this notebook, a few modules need to be imported. The installation of third-party modules is not covered here." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Import general modules used e.g. for plotting.\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import sys\n", + "import torch\n", + "\n", + "# Importing Hopfield-specific modules.\n", + "from auxiliary.data import LatchSequenceSet\n", + "from modules import Hopfield, HopfieldPooling, HopfieldLayer\n", + "\n", + "# Import auxiliary modules.\n", + "from distutils.version import LooseVersion\n", + "from typing import List, Tuple\n", + "\n", + "# Importing PyTorch specific modules.\n", + "from torch import Tensor\n", + "from torch.nn import Flatten, Linear, LSTM, Module, Sequential\n", + "from torch.nn.functional import binary_cross_entropy_with_logits\n", + "from torch.nn.utils import clip_grad_norm_\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.data.sampler import SubsetRandomSampler\n", + "\n", + "# Set plotting style.\n", + "sns.set()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specific minimum versions of Python itself as well as of some used modules is required." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Installed Python version: 3.8.8 (βœ“)\n", + "Installed PyTorch version: 1.7.0 (βœ“)\n" + ] + } + ], + "source": [ + "python_check = '(\\u2713)' if sys.version_info >= (3, 8) else '(\\u2717)'\n", + "pytorch_check = '(\\u2713)' if torch.__version__ >= LooseVersion(r'1.5') else '(\\u2717)'\n", + "\n", + "print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} {python_check}')\n", + "print(f'Installed PyTorch version: {torch.__version__} {pytorch_check}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Create Dataset

\n", + "\n", + "We study an easy example of learning long-term dependencies by using a simple latch task. \n", + "The latch task was introduced by Hochreiter and Mozer:
\n", + "Sepp Hochreiter, Michael Mozer, 2001. A discrete probabilistic memory model for discovering dependencies in time. Artificial Neural Networks -- ICANN 2001, 13, pp.661-668.

\n", + "The essence of this task is that a sequence of inputs is presented, beginning with one of two symbols, A or B, and after a variable number of time steps, the model has to output a corresponding symbol. Thus, the task requires memorizing the original input over time. It has to be noted, that both class-defining symbols must only appear at the first position of an instance. Defining arguments are:\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ArgumentValue (used in this demo)Description
num_samples4096Amount of samples of the full dataset.
num_instances32Amount of instances per sample (sample length).
num_characters20Amount of different characters (size of the one-hot encoded vector).
\n", + "\n", + "Let's define the dataset using previously mentioned properties as well as a logging directory for storing all auxiliary outputs like performance plots." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "latch_sequence_set = LatchSequenceSet(\n", + " num_samples=4096,\n", + " num_instances=32,\n", + " num_characters=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "log_dir = f'resources/'\n", + "os.makedirs(log_dir, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Create Auxiliaries

\n", + "\n", + "Before digging into Hopfield-based networks, a few auxiliary variables and functions need to be defined. This is nothing special with respect to Hopfield-based networks, but rather common preparation work of (almost) every machine learning setting (e.g. definition of a data loader as well as a training loop). We will see, that this comprises the most work of this whole demo." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(r'cuda:0' if torch.cuda.is_available() else r'cpu')\n", + "\n", + "# Create data loader of training set.\n", + "sampler_train = SubsetRandomSampler(list(range(512, 4096 - 512)))\n", + "data_loader_train = DataLoader(dataset=latch_sequence_set, batch_size=32, sampler=sampler_train)\n", + "\n", + "# Create data loader of validation set.\n", + "sampler_eval = SubsetRandomSampler(list(range(512)) + list(range(4096 - 512, 4096)))\n", + "data_loader_eval = DataLoader(dataset=latch_sequence_set, batch_size=32, sampler=sampler_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def train_epoch(network: Module,\n", + " optimiser: AdamW,\n", + " data_loader: DataLoader\n", + " ) -> Tuple[float, float]:\n", + " \"\"\"\n", + " Execute one training epoch.\n", + " \n", + " :param network: network instance to train\n", + " :param optimiser: optimiser instance responsible for updating network parameters\n", + " :param data_loader: data loader instance providing training data\n", + " :return: tuple comprising training loss as well as accuracy\n", + " \"\"\"\n", + " network.train()\n", + " losses, accuracies = [], []\n", + " for sample_data in data_loader:\n", + " data, target = sample_data[r'data'], sample_data[r'target']\n", + " data, target = data.to(device=device), target.to(device=device)\n", + "\n", + " # Process data by Hopfield-based network.\n", + " model_output = 
network.forward(input=data)\n", + "\n", + " # Update network parameters.\n", + " optimiser.zero_grad()\n", + " loss = binary_cross_entropy_with_logits(input=model_output, target=target, reduction=r'mean')\n", + " loss.backward()\n", + " clip_grad_norm_(parameters=network.parameters(), max_norm=1.0, norm_type=2)\n", + " optimiser.step()\n", + "\n", + " # Compute performance measures of current model.\n", + " accuracy = (model_output.sigmoid().round() == target).to(dtype=torch.float32).mean()\n", + " accuracies.append(accuracy.detach().item())\n", + " losses.append(loss.detach().item())\n", + " \n", + " # Report progress of training procedure.\n", + " return (sum(losses) / len(losses), sum(accuracies) / len(accuracies))\n", + "\n", + "\n", + "def eval_iter(network: Module,\n", + " data_loader: DataLoader\n", + " ) -> Tuple[float, float]:\n", + " \"\"\"\n", + " Evaluate the current model.\n", + " \n", + " :param network: network instance to evaluate\n", + " :param data_loader: data loader instance providing validation data\n", + " :return: tuple comprising validation loss as well as accuracy\n", + " \"\"\"\n", + " network.eval()\n", + " with torch.no_grad():\n", + " losses, accuracies = [], []\n", + " for sample_data in data_loader:\n", + " data, target = sample_data[r'data'], sample_data[r'target']\n", + " data, target = data.to(device=device), target.to(device=device)\n", + "\n", + " # Process data by Hopfield-based network.\n", + " model_output = network.forward(input=data)\n", + " loss = binary_cross_entropy_with_logits(input=model_output, target=target, reduction=r'mean')\n", + "\n", + " # Compute performance measures of current model.\n", + " accuracy = (model_output.sigmoid().round() == target).to(dtype=torch.float32).mean()\n", + " accuracies.append(accuracy.detach().item())\n", + " losses.append(loss.detach().item())\n", + "\n", + " # Report progress of validation procedure.\n", + " return (sum(losses) / len(losses), sum(accuracies) / len(accuracies))\n", + 
"\n", + "\n", + "def operate(network: Module,\n", + " optimiser: AdamW,\n", + " data_loader_train: DataLoader,\n", + " data_loader_eval: DataLoader,\n", + " num_epochs: int = 1\n", + " ) -> Tuple[pd.DataFrame, pd.DataFrame]:\n", + " \"\"\"\n", + " Train the specified network by gradient descent using backpropagation.\n", + " \n", + " :param network: network instance to train\n", + " :param optimiser: optimiser instance responsible for updating network parameters\n", + " :param data_loader_train: data loader instance providing training data\n", + " :param data_loader_eval: data loader instance providing validation data\n", + " :param num_epochs: amount of epochs to train\n", + " :return: data frame comprising training as well as evaluation performance\n", + " \"\"\"\n", + " losses, accuracies = {r'train': [], r'eval': []}, {r'train': [], r'eval': []}\n", + " for epoch in range(num_epochs):\n", + " \n", + " # Train network.\n", + " performance = train_epoch(network, optimiser, data_loader_train)\n", + " losses[r'train'].append(performance[0])\n", + " accuracies[r'train'].append(performance[1])\n", + " \n", + " # Evaluate current model.\n", + " performance = eval_iter(network, data_loader_eval)\n", + " losses[r'eval'].append(performance[0])\n", + " accuracies[r'eval'].append(performance[1])\n", + " \n", + " # Report progress of training and validation procedures.\n", + " return pd.DataFrame(losses), pd.DataFrame(accuracies)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def set_seed(seed: int = 42) -> None:\n", + " \"\"\"\n", + " Set seed for all underlying (pseudo) random number sources.\n", + " \n", + " :param seed: seed to be used\n", + " :return: None\n", + " \"\"\"\n", + " torch.manual_seed(42)\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + "\n", + "def plot_performance(loss: pd.DataFrame,\n", + " accuracy: pd.DataFrame,\n", + " 
log_file: str\n", + " ) -> None:\n", + " \"\"\"\n", + " Plot and save loss and accuracy.\n", + " \n", + " :param loss: loss to be plotted\n", + " :param accuracy: accuracy to be plotted\n", + " :param log_file: target file for storing the resulting plot\n", + " :return: None\n", + " \"\"\"\n", + " fig, ax = plt.subplots(1, 2, figsize=(20, 7))\n", + " \n", + " loss_plot = sns.lineplot(data=loss, ax=ax[0])\n", + " loss_plot.set(xlabel=r'Epoch', ylabel=r'Cross-entropy Loss')\n", + " \n", + " accuracy_plot = sns.lineplot(data=accuracy, ax=ax[1])\n", + " accuracy_plot.set(xlabel=r'Epoch', ylabel=r'Accuracy')\n", + " \n", + " ax[1].yaxis.set_label_position(r'right')\n", + " fig.tight_layout()\n", + " fig.savefig(log_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

LSTM-based Network

\n", + "\n", + "The instantiation of the heart of an LSTM-based network, the module LSTM, is rather straightforward. Only two arguments, the size of the input as well as the size of the hidden state, need to be set.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ArgumentValue (used in this demo)Description
input_sizenum_characters (20)Size (depth) of the input.
hidden_size4Size (depth) of the hidden state.
...defaultThe remaining arguments are not explicitly used in this example.
\n", + "\n", + "An additional output projection is defined, to downproject the hidden state of the last time step of the LSTM to the correct output size. Afterwards, everything is wrapped into a container of type torch.nn.Sequential and a corresponding optimiser is defined." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class LSTMNetwork(Module):\n", + " def __init__(self, input_size: int, hidden_size: int):\n", + " \"\"\"\n", + " Initialize a new instance of an LSTM-based network.\n", + " \n", + " :param input size: size (depth) of the input\n", + " :param hidden_size: size (depth) of the hidden state\n", + " \"\"\"\n", + " super(LSTMNetwork, self).__init__()\n", + " self.lstm = LSTM(input_size, hidden_size, batch_first=True)\n", + " self.projection = Linear(hidden_size, 1)\n", + " \n", + " def forward(self, input: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Compute result of LSTM-based network on specified data.\n", + " \n", + " :param input: data to be processed by the LSTM-based network\n", + " :return: result as computed by the LSTM-based network\n", + " \"\"\"\n", + " out, _ = self.lstm.forward(input=input) \n", + " return self.projection.forward(input=out[:, -1, :]).flatten()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "network = LSTMNetwork(\n", + " input_size=latch_sequence_set.num_characters,\n", + " hidden_size=4).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate LSTM-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/lstm_base.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Hopfield-based Network

\n", + "\n", + "The instantiation of the heart of a Hopfield-based network, the module Hopfield, is even simpler. Only one argument, the size of the input, needs to be set.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ArgumentValue (used in this demo)Description
input_sizenum_characters (20)Size (depth) of the input (state pattern).
...defaultThe remaining arguments are not explicitly used in this example.
\n", + "\n", + "An additional output projection is defined, to downproject the result of Hopfield to the correct output size. Afterwards, everything is wrapped into a container of type torch.nn.Sequential and a corresponding optimiser is defined. Now the Hopfield-based network and all auxiliaries are set up and ready to associate!" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield = Hopfield(\n", + " input_size=latch_sequence_set.num_characters)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield.output_size * latch_sequence_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield, Flatten(), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate Hopfield-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_base.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Adapt Hopfield-based Network

\n", + "

We can now explore the functionality of our Hopfield layer Hopfield.

\n", + "\n", + "As described in the paper the Hopfield layer allows:\n", + "- association of two sets\n", + "- multiple updates\n", + "- variable beta\n", + "- changing the dimension of the associative space\n", + "- pattern normalization\n", + "- static patterns for fixed pattern search\n", + "\n", + "This time, additional arguments are set to influence the training as well as the validation performance of the Hopfield-based network.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ArgumentValue (used in this demo)Description
input_sizenum_characters (20)Size (depth) of the input (state pattern).
hidden_size8Size (depth) of the association space.
num_heads8Amount of parallel association heads.
update_steps_max3Number of updates in one Hopfield head.
scaling0.25Beta parameter that determines the kind of fixed point.
dropout0.5Dropout probability applied on the association matrix.
...defaultThe remaining arguments are not explicitly used in this example.
" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield = Hopfield(\n", + " input_size=latch_sequence_set.num_characters,\n", + " hidden_size=8,\n", + " num_heads=8,\n", + " update_steps_max=3,\n", + " scaling=0.25,\n", + " dropout=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield.output_size * latch_sequence_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield, Flatten(), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_adapted.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Hopfield-based Pooling

\n", + "\n", + "The previous examples manually downprojected the result of Hopfield by applying a linear layer afterwards. It would've also been possible to apply some kind of pooling. Exactly for such use cases, the module HopfieldPooling might be the right choice. Internally, a state pattern is trained, which in turn is used to compute pooling weights with respect to the input." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_pooling = HopfieldPooling(\n", + " input_size=latch_sequence_set.num_characters)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_pooling.output_size, out_features=1)\n", + "network = Sequential(hopfield_pooling, output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate Hopfield-based Pooling

" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_pooling.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Adapt Hopfield-based Pooling

\n", + "

We can now again explore the functionality of our Hopfield-based pooling layer HopfieldPooling.

\n", + "\n", + "Again, additional arguments are set to influence the training as well as the validation performance of the Hopfield-based pooling.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ArgumentValue (used in this demo)Description
input_sizenum_characters (20)Size (depth) of the input (state pattern).
hidden_size8Size (depth) of the association space.
num_heads8Amount of parallel association heads.
update_steps_max3Number of updates in one Hopfield head.
scaling0.25Beta parameter that determines the kind of fixed point.
dropout0.5Dropout probability applied on the association matrix.
...defaultThe remaining arguments are not explicitly used in this example.
" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_pooling = HopfieldPooling(\n", + " input_size=latch_sequence_set.num_characters,\n", + " hidden_size=8,\n", + " num_heads=8,\n", + " update_steps_max=3,\n", + " scaling=0.25,\n", + " dropout=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_pooling.output_size, out_features=1)\n", + "network = Sequential(hopfield_pooling, output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABZQAAAHsCAYAAABFbSiOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAACIHklEQVR4nOzdeXxU5d3///eZLclsWWBCABUXrKhAXaggKmpVUCRqFXeK1KWtvS23/FpaXFoVq9ZWS9WqVdrq162VKi5Yi7jUpUWtercCrlWrIkgIJCQzmUlmOef3xyQjMQlJJpkt83o+bh9kzrnOOZ+5wLvjm2s+l2FZliUAAAAAAAAAAHphy3UBAAAAAAAAAIDCQKAMAAAAAAAAAOgTAmUAAAAAAAAAQJ8QKAMAAAAAAAAA+oRAGQAAAAAAAADQJwTKAAAAAAAAAIA+IVAGAAAAAAAAAPSJI9cFDERjY4tM08rqM4cN82rr1lBWnzlUMHfpY+7Sw7ylj7lLH3OXPuYufcxd/9hshiorPbkuo9/47Fs4mLf0MXfpY+7Sx9ylj7lLD/OWPuaufzL1ubegA2XTtLL+obrjuUgPc5c+5i49zFv6mLv0MXfpY+7Sx9wNfXz2LSzMW/qYu/Qxd+lj7tLH3KWHeUsfc5d7tLwAAAAAAAAAAPQJgTIAAAAAAAAAoE8KuuUFAAAA+i6RiKuxsV7xeDTXpWSFw+FSZWVAdjsfeQEAAIDBwqdrAACAItHYWK/SUrc8nhoZhpHrcjLKsiy1tDSrsbFew4ePzHU5AAAAwJBBywsAAIAiEY9H5fH4h3yYLEmGYcjj8RfNamwAAAAgWwiUAQAAikgxhMkdium9AgAAANlCoAwAAAAAAAAA6BMCZQAAAGRdKBTSJZf8sM/j3333bf3851dnsCIAAAAAfcGmfAAAAMi6YLBZ//nPe30eP27cPlq0aJ8MVgQAAACgLwiUAQAAkHW//vUvtWVLvS655If65JP/qry8QiUlJbrmml/ouuuuVn39Zm3ZUq9Jkw7SokU/0b/+9Yb+8Ic79Zvf3KmLLvq29tlnX7355r+1bVujLr54oQ4++JBcvyUAAACgKBAoAwAAFKF/rP1cf1/zeUbufejEkTpkwsgdjrn44oX6/ve/o/nz/z+deuoJ+vOfb9HIkaP09NMrteeeX9HPfna9YrGY5sw5Ve+9926X62OxuO644y79/e8vaunS2wmUAQAAgCwhUAYAAEBOVVZWaeTIUZKkY445Vm+/vU7Llj2gjz/+r5qamhSJhLtcM3nywZKk3XffQ8Fgc1brBQAAAIoZgTIAAEAROmRC76uIs6WkpCT180MP/UnPP/+cTjjhG5o9+yD9978fyrKsLte4XC5JkmEY3Z4HAAAAkBm2XBcAAACA4mO325VIJLocf+21V3XCCSdr+vTjFI1G9Z//vC/TNHNQYc9CoZBmzZqlzz77rMu5d955RyeffLJmzJihyy67TPF4XJK0ceNGnX322Tr22GN14YUXqqWlJdtlAwAAAIOCQBkAAABZV1U1TCNG1Ojaa6/qdPy0087SXXfdqblzT9dNN92o8eMn6vPPN+aoyq7efPNNnXnmmfr444+7Pb9w4UL99Kc/1VNPPSXLsrRs2TJJ0lVXXaWzzjpLK1eu1Pjx43XbbbdlsWoAAABg8NDyAgAAAFnncDj029/+ocvxAw/8mv74x+XdXnPAAZMkSb/5zZ2pYyNHjtJDD63ITJHdWLZsma644gr96Ec/6nJuw4YNam1t1X777SdJOvnkk3XzzTfr1FNP1WuvvaZbb701dXzOnDlauHBh1uoGAAAABguBcj/8Y+3nilmf64iJ+dFvEAAAANl1zTXX9Hhu8+bNCgQCqdeBQEB1dXVqbGyU1+uVw+HodHwossyElIjJcJZKkhJb18uKRaRYm6x4W+pX51cOleFwKfb+35XY/NF256JSvE2uA0+SY9Tein34qqL/6voXBo7dD1LJASfIbN6syKqbu5y
3+QIqm/G/kqTwiutktXVtMeKetUhGqVetr/xJic/WdTlfMvl0OXaekFYNbQ6bzLJhOa0hH+YhnRo+2/S24vHObW76W4MlyTItmZZkWZZipVXaOH6eIm1x7bLmDtliYZmmJdOyZFnJcctdJ6kp5tK0xD+0q7W+yzOetx2sj41dtJf5gQ623uhy/l1jrF6xHagKq0knmSu7nN8mvx61HydJOj3xmMrU2mXMn2wnqtUo1RHm6rRqWGnfU6t1QE5ryId5SKeGjwxD2q4ff7HOQzo1fObcTXtE3yv6eehvDdv/mSvmeUinhrjTo0OjLxXkPPzLe7hOOevELtcUIgLlfnjrvw36uC5IoAwAAIAuTNOUYRip15ZlpTYN3P64pC6v+2LYMO+Aa+yPtk3/VeNLf5Ut1iYr2ioz1ior2qbSMfuqfNJxim2r06Y/XSMz2ior1iYz1iol4nIOG6Wdv3uLJOmT+29UomVbl3uPOOAwOXw+bX7lQ0X/+7pszhIZrlLZnSUynKWqKC9TWcCncFOVmgOju1zvDgTkD/gUL2nTlm7OO3zDNDzgkyRtDoyU2RbpMmZ4wC9bqUfbAtVqbet6j4pAlUqHYg1WMkhNJCzFE5a2tZlqCUdkxEpls1cqYSYDDkOSYUjbNrcpGmuQtzEhn7MqdVyGIUNSPFGqRGNEjtaEStwBdfzRNtrPW54qRWXIZhgy/AEZMZ8MQzJktN9H8lV65SzzKDQ8oEhrjZSwlDBNxdt/3Ro01fTJNtnqovJZ/tTxRMJS3LT08bomvb7mNTkijTrKsqfeQ4dtZlzL310jSTrLY1OpUSJDkt1uyGG3yW4z5ChxqabcK0e4Uq3Rrr9Xld4KJUr8Km8tV2vL8C7nS0srtbPHL3dcam3qet6y+7RzhV+SlGisUqsZ7TJmVGW5YrYSOYLp1eAqrch5DfkwD9SQ7Ro8eVBDPswDNVBD7zUMC1Qq0P6/y4XOsAp4W+ytW0MyzeyV/6dn/6MX39yo2/6/w7P2zKEkEPCpvj6Y6zIKEnOXHuYtfcxd+pi79DF36evr3G3a9IlqasZkoaL80d17ttmMAYezX//613XPPfdop512Sh3bsGGD5s2bp6efflqS9Prrr+vmm2/W73//e02ePFmvvfaa7Ha7Pv/8c82ZM0fPPvtsv56Z7c++sfdeUusLv5fsDslRIsNRIsNZklwNOukbMsNNalt9X6dzcpbIKCuXa+8jJEnx9WuTwaOj/VzHr6V+Gbahu51Lpv7/mWlZaosmFGmLt/+TUDj1c1yRaPuvrT0cb0te++XAtdA47Da5S+wqK3Gk/nGXOFTafszdh+Muhy2tv9jJZ/zvaPqYu/Qxd+lh3tLH3PXPYHzu7Q4rlPvB73GpNZpQWzShEpc91+UAAAAgj4wePVolJSV64403dOCBB+qxxx7TtGnT5HQ6NWnSJD355JOqra3Vo48+qmnTpuW63F459pyq3abO0Jat4W7P29zlKjv6f3Z8j50nZKK0jEqYprY0tSoaM1NtERKmlfzZtJSw2n81LVntv5rbHTNNS25PiZqbI19cZyXvmxpjKXW/7e/fca943OwcCLfFFW5LqLUtrt6iYMNQMkR1dQSqdlV6SzRquOeLgNVl7xSubh+2lpbYZUjtNXfUZaZq/vL7/vLcWGbX9/Pln788D9vf3+stkRlPqGy7+stKt6vV5ZDTMXT/MgIAgEJAoNwPfrdLktQUjqraVZbjagAAAJAPLrjgAs2fP18TJkzQDTfcoMsvv1yhUEj77ruv5s6dK0m64oortGjRIt1+++0aOXKkfvWrX+W46t4ZNrsM29BdRGFZlhqDbfqsvkUbtoT02ebkrxu3hBVPmL3fYAAMQ7Lbkm0gbDZDdpshw0j+arMZctpt7QGqXYGKsm6C386rc7c/XuK0F/TKW1aeAQCQ/wiU+8HvSQbKzS1RVVcQKAMAABSr5557LvXz0qVLUz+PGzdODz30UJfxo0eP1r333puV2tBVS2t
MG+pb9Fl9qNOv4bZ4akyF16WdAl7tfWBlcjWvy5EKeDv+sRvb/bxdIJx6bTNkM6TqgE+NjeH2MeoUGtvarwMAAChUBMr9UL5doAwAAAAgv0RjCX2+Ndw5ON7SosZgW2pMWYlDowMeHbTPCO0U8Gj0cI9GB7zyljkHrY5h5WUyo/HeBwIAABQgAuV+8BMoAwAA5JXf//4OSdJ5530nx5Ugm0zTUl1juFNo/Fl9izY3htWx5bjDbtOoYW6N26UyGRwHvNop4FGlr6SgW0IAAADkGoFyP/jcyVULBMoAAABA5lmWpW2haJdWFRu3tigWT/Y5NiRVV5ZpdMCryXtXp4Lj6soy2W1s3gYAADDYCJT7wWG3yed2qilMoAwAAApfeMV13R53114iSWpdfb/MrZ92OV9y8FmyDx+j2HsvKfb+33u8vi/uvfdu/e1vTyuRMDV58hTF43EFAiN05plzJEmXXbZQ06fP1M4776wlS36pSCSixsYGffOb83TSSbP7/BwUlnUfbdUTqz/Whi0tamn9onVEeXuf4yN3Ga2dAl7tVO3RyGEelTiH7uaBAAAA+YZAuZ8qfCWsUAYAABgEr7yyWu+9946WLr1HhmHo6qt/qjFjdtUzzzylM8+co3C4RevWrdWVV16r2267Weecc54mTTpIGzZ8pnnzziJQHsJeWvO5Ptkc0sH7jEitOB7sPscAAABID4FyP1V4SwmUAQDAkNDbSuLSqWfv8Lxzr8Pk3OuwtJ//+uv/1Ntvr9N5531TktTW1qoRI2oUjbbps8/Wa+3aN3XIIYfJ6XTqoosu1quvvqx7771LH374gSKRcNrPRf4LRWLaOeDV3GPH5boUAAAAfAmBcj9V+Er0fkNLrssAAAAoeKaZ0Gmnnakzzki2twgGg7Lb7fJ4vHr22VVat26N5syZJ0n66U8Xyefz65BDDtNRR03XM888lcPKkWnBcFSBirJclwEAAIBusEtFP1X4StRMD2UAAIABO+CAr+mpp55UOBxWPB7XJZf8QM8//6ymTz9Wzz33tD77bL0mTtxPkvTaa//U+ed/V4cddoReeWW1JCmRSOSwemRSMBKjvQUAAECeYoVyP1V4SxRpSygWT8jpYPMPAACAdB166DR98MH7+va358k0E5o8eaqOO26WDMNQeXmF9t13ggzDkCSde+4FuvDC81VS4tIee+ypkSNH6fPPN+b4HSATLMtSKByTz+3KdSkAAADoBoFyP1X4SiRJTS1RDS/na3gAAAADMW/e+Zo37/wux2+++bedXp9xxpxUa4ztnXfedzJWG3Ij0pZQwrRYoQwAAJCnaHnRTx2BcnNLLMeVAAAAAENPKJJsL+dzEygDAADkIwLlfqrwdgTK9FEGAAAABlswnFy4QaAMAACQnwiU+ym1QpmN+QAAQAGyLCvXJWRNMb3XoSQYSQbK3jJ6KAMAAOQjAuV+6lih3MQKZQAAUGBsNrsSiXiuy8iaRCIum41NlAtNiBXKAAAAeY1AuZ9cTrvKShy0vAAAAAWnrMyrYHCbLMvMdSkZZ1mmgsFGlZV5c10K+inY3kOZTfkAAADykyPXBRQiv8dFoAwAAAqO11uuxsZ61dV9Jmmot4Mw5HKVyustz3Uh6KdQOCaH3aZSF6vLAQAA8hGBchrK3U4CZQAAUHAMw1BVVXWuywB2KBiOyed2yjCMXJcCAACAbtDyIg1+j4tN+QAAAIAMCEVitLsAAADIYwTKaaDlBQAAAJAZwXCUDfkAAADyWEYD5RUrVmjmzJmaPn267r///i7nP/roI33zm9/UCSecoPPOO09NTU2ZLGfQ+D0utbTGFYsP/Q1tAAAAgGwKskIZAAAgr2UsUK6rq9OSJUv0wAMP6NFHH9WDDz6oDz74IHXesixdeOGFuuCCC/T4449r77331p133pmpcgaV3+OSlFw9AQAAAGDwhMIx+dyuXJcBAACAHmQsUF69erWmTJmiiooKud1uzZgxQytXrkydf+utt+R
2uzVt2jRJ0ne/+12dffbZmSpnUJW3f8Btou0FAAAAMGjiCVPhtrh8rFAGAADIWxkLlDdv3qxAIJB6XV1drbq6utTrTz/9VMOHD9ell16qb3zjG7riiivkdrszVc6g6lihTB9lAAAAYPC0RGKSRA9lAACAPObI1I1N05RhGKnXlmV1eh2Px/XPf/5T9913nyZMmKBf//rX+vnPf66f//znfX7GsGHeQa25r3bbpUqSZNlsCgR8OamhUDFf6WPu0sO8pY+5Sx9zlz7mLn3MHYaCYDgZKHtpeQEAAJC3MhYo19TU6PXXX0+9rq+vV3V1dep1IBDQmDFjNGHCBEnSrFmzNH/+/H49Y+vWkEzTGpyC+ygQ8CnemlyZvKGuWfX1waw+v5AFAj7mK03MXXqYt/Qxd+lj7tLH3KWPuesfm83I2cIE7FiwfYUym/IBAADkr4y1vJg6dapefvllNTQ0KBKJaNWqVal+yZK0//77q6GhQe+++64k6bnnntO+++6bqXIGlctpV6nLTg9lAAAAYBB1bHpNywsAAID8lbEVyiNGjNCCBQs0d+5cxWIxzZ49WxMnTtQFF1yg+fPna8KECbr11lt1+eWXKxKJqKamRr/4xS8yVc6g83tc9FAGAAAABlGoo4cyK5QBAADyVsYCZUmqra1VbW1tp2NLly5N/fzVr35VDz30UCZLyBgCZQAAAGBwhdp7KHsIlAEAAPJWxlpeDHXlbpea2z/wAgAAABi4YDgmd4lDDjv/mQIAAJCv+KSWJlYoAwAAAIMrGInSPxkAACDPESinye9xKRSJKZ4wc10KAAAAMCQEwzF5CZQBAADyGoFymvwel6Tkh14AAAAAAxeKxOQrc+W6DAAAAOwAgXKa/O7kB13aXgAAAACDIxRhhTIAAEC+I1BOU3n7CuXmMIEyAAAAMFCWZSkYjspXRqAMAACQzwiU0+T3JD/oskIZAAAAGLjWaELxhCWfm5YXAAAA+YxAOU0dPZQJlAEAAICBC0aSe5N4WaEMAACQ1wiU01TqcsjltKmJQBkAAAAYsFD7Ztc+eigDAADkNQLlAfC7XfRQBgAAAAZBsP1zNZvyAQAA5DcC5QEo97hoeQEAAAAMglB7yws25QMAAMhvBMoD4CdQBgAAAAZFMNXygk35AAAA8hmB8gAQKAMAAACDIxiJym4zVOqy57oUAAAA7ACB8gD43S4FIzGZppXrUgAAAICCFgrH5HM7ZRhGrksBAADADhAoD4Df45JlScH2fm8AAAAA0hMMx+Qto90FAABAviNQHoByT/IDL20vAAAAgIEJRZIrlAEAAJDfCJQHwE+gDAAAAAyKYDhKoAwAAFAACJQHgEAZAAAAGByhSEzeMgJlAACAfEegPAB+dzJQbiJQBgAAANKWME21tMblc9NDGQAAIN8RKA9AWYldDrtNzWECZQAAACBdoUhcklihDAAAUAAIlAfAMAyVe5y0vAAAAAAGINS+QIMeygAAAPmPQHmA/B4XgTIAAAAwAMFwTJLkY4UyAABA3iNQHiC/m0AZAAAAGIhQJBkoe+mhDAAAkPcIlAfI73GxKR8AAAAwAEFaXgAAABQMAuUB8ntcCoZjMi0r16UAAAAABSnYsUKZlhcAAAB5j0B5gPwel0zLSn1NDwAAAED/hMIxlZU45LDznycAAAD5jk9sA1TuSfZ5o48yAAAAkJ5gJMaGfAAAAAWCQHmACJQBAACAgQmFo/RPBgAAKBAEygPkJ1AGAAAABiQYjtE/GQAAoEAQKA8QgTIAAAAwMMFITF5WKAMAABQEAuUBcpc45LAbagoTKAMAAAD9ZbVvcO1zu3JdCgAAAPqAQHmADMOQz+1ihTIAAACQhrZYQrG4yaZ8AAAABYJAeRD4PS41t8RyXQYAAABQcELh5OdoWl4AAAAUBgLlQVDuYYUyAAAAkI5gJBko+8poeQEAAFAICJQHgd/tUjM9lAEAAIa8FStWaObMmZo+fbruv//+LudfeOEF1db
Wqra2Vj/4wQ/U0tIiSXrkkUd06KGH6sQTT9SJJ56oJUuWZLv0vBVsX6HsY4UyAABAQXDkuoChwN++QtmyLBmGketyAAAAkAF1dXVasmSJli9fLpfLpTPOOEOTJ0/W2LFjJUnNzc1atGiR7r33Xo0dO1ZLly7VkiVLdPnll2vdunVatGiRZs2aleN3kX+C7QszaHkBAABQGFihPAj8HpcSpqWW1niuSwEAAECGrF69WlOmTFFFRYXcbrdmzJihlStXps5//PHHGjVqVCpgPvLII/XMM89IktauXatHHnlEtbW1+uEPf6impqacvId8FEq1vCBQBgAAKASsUB4Efk/yw29zS1RePggDAAAMSZs3b1YgEEi9rq6u1po1a1Kvd911V23atEnvvvuuxo0bp7/+9a/asmWLJCkQCOjcc8/VAQccoF/96ldavHixbrzxxn49f9gw7+C8kX4KBHwZvb9pGLLbDO2yU+WQ+rZfpudtKGPu0sfcpY+5Sx9zlx7mLX3MXe4RKA+CcndyA5HmlqhGDffkuBoAAABkgmmanQLPL7c78/v9uv766/WTn/xEpmnqtNNOk9OZXGxw6623psadf/75OuaYY/r9/K1bQzJNawDvoP8CAZ/q64MZfUbdlpC8ZU5t2RLK6HOyKRvzNlQxd+lj7tLH3KWPuUsP85Y+5q5/bDYjI4sSaHkxCPye9kCZjfkAAACGrJqaGtXX16de19fXq7q6OvU6kUiopqZGf/7zn/Xwww9r77331s4776xgMKi77747Nc6yLNnt9myWnteC4Rgb8gEAABQQAuVB0BEoN7UQKAMAAAxVU6dO1csvv6yGhgZFIhGtWrVK06ZNS503DEPnnnuu6urqZFmW7r77bs2cOVNut1u/+93v9Oabb0qS7rvvvrRWKA9VwUiMtnEAAAAFhJYXg8BT5pTNMNRMoAwAADBkjRgxQgsWLNDcuXMVi8U0e/ZsTZw4URdccIHmz5+vCRMmaPHixTr//PMVjUZ18MEH67zzzpPdbtevf/1rXXnllWptbdWuu+6qX/ziF7l+O3kjFI5p5+rc9IcGAABA/xEoDwKbYcjncRIoAwAADHG1tbWqra3tdGzp0qWpn4844ggdccQRXa6bNGmSHnnkkUyXV5CC4ai8tLwAAAAoGLS8GCTlbheBMgAAANAPCdNUuDUuHy0vAAAACgaB8iDxe1xsygcAAAD0Q0trXJYkn9uV61IAAADQRwTKg8TvYYUyAAAA0B/BcEyS2JQPAACggBAoDxK/x6Wmlpgsy8p1KQAAAEBBCLV/w89HD2UAAICCQaA8SPxul+IJU5G2RK5LAQAAAAoCK5QBAAAKD4HyICn3JPu+0UcZAAAA6JtQJBko00MZAACgcBAoDxJ/R6BMH2UAAACgT4IRVigDAAAUGgLlQUKgDAAAAPRPMBxVqcsup4P/LAEAACgUfHIbJB2BchOBMgAAANAnoUiMDfkAAAAKDIHyIPGVOWUYrFAGAAAA+ioYjslbRv9kAACAQkKgPEhsNkO+Mieb8gEAAAB9FAqzQhkAAKDQZDRQXrFihWbOnKnp06fr/vvv73L+N7/5jY488kideOKJOvHEE7sdU0j8HpeaQgTKAAAAQF8EI1H52JAPAACgoDgydeO6ujotWbJEy5cvl8vl0hlnnKHJkydr7NixqTHr1q3Tr371K+2///6ZKiOr/B4XK5QBAACAPgqFY/KyQhkAAKCgZGyF8urVqzVlyhRVVFTI7XZrxowZWrlyZacx69at0x133KHa2lotXrxYbW1tmSonK/weFz2UAQAAgD5oiyUUjZvyuemhDAAAUEgytkJ58+bNCgQCqdfV1dVas2ZN6nVLS4v23ntvLVy4UGPGjNGiRYt02223acGCBX1+xrBh3kGtua8CAV+3x2uGe/V/72/R8OFeGYaR5aoKQ09zh94xd+lh3tLH3KWPuUsfc5c+5g6FJtj+zT4vLS8AAAAKSsYCZdM0O4WqlmV1eu3
xeLR06dLU63PPPVeXXnppvwLlrVtDMk1rcAruo0DAp/r6YLfnnDYpGkto/YZtKivJ2NQWrB3NHXaMuUsP85Y+5i59zF36mLv0MXf9Y7MZOVuYgC+EIjFJYlM+AACAApOxlhc1NTWqr69Pva6vr1d1dXXq9caNG/XQQw+lXluWJYejsENYf/vX9eijDAAAAOxYMNweKJfR8gIAAKCQZCxQnjp1ql5++WU1NDQoEolo1apVmjZtWup8aWmpfvnLX2r9+vWyLEv333+/jjnmmEyVkxXlnvZAmT7KAAAAwA6FwqxQBgAAKEQZC5RHjBihBQsWaO7cuTrppJM0a9YsTZw4URdccIHWrl2rqqoqLV68WBdeeKGOPfZYWZalb33rW5kqJyv8BMoAAABAn6R6KBMoAwAAFJSM9piora1VbW1tp2Pb902eMWOGZsyYkckSsopAGQAAAOibYCQmm2Gw9wgAAECBydgK5WLkcztlSGoiUAYAAAB2KBSJyet2yrbdxt0AAADIfwTKg8hus8lT5lRzez84AAAAAN0LhmPyldHuAgAAoNAQKA+yco+LlhcAAABAL0LhKBvyAQAAFCAC5UHmJ1AGAAAAehWMxORlhTIAAEDBIVAeZATKAAAAQO+C4Zh8bleuywAAAEA/ESgPMr/bpaYwgTIAAADQE9O01NLKCmUAAIBCRKA8yPwep9qiCbXFErkuBQAAAMhLLa0xWZbkpYcyAABAwSFQHmR+T/Jre7S9AAAAALoXisQkiU35AAAAChCB8iArJ1AGAAAAdigYbg+Uy+ihDAAAUGgIlAcZK5QBAACAHUsFyqxQBgAAKDgEyoPM375TNRvzAQAAAN0LRpKfldmUDwAAoPAQKA8yVigDAAAAOxZihTIAAEDBIlAeZA67TZ5SB4EyAAAA0INQJKYSl11Ohz3XpQAAAKCfCJQzwO9xESgDAAAAPQiGo/LR7gIAAKAgEShngN9NoAwAAAD0JBiJ0e4CAACgQBEoZ4Df41JTe184AAAAAJ0FwzF5y1y5LgMAAABpIFDOAFpeAAAAAD0LhVmhDAAAUKgIlDPA73Ep0hZXLJ7IdSkAAABA3glGovLSQxkAAKAgEShnQLkn+fW95hbaXgAAAADba4slFI2ZrFAGAAAoUATKGeB3twfKYdpeAAAAANtriSQXXfjc9FAGAAAoRATKGeBvX6HcRB9lAAAAoJNg++bVtLwAAAAoTATKGeD3JD8cszEfAAAA0FkwkvyMTMsLAACAwkSgnAHlrFAGAAAAusUKZQAAgMJGoJwBToddZSV2VigDAAAAXxIK00MZAACgkBEo94MVjyrevLVPY/1uF4EyAAAA8CXBSEyGIblLHbkuBQAAAGkgUO6H1pfu1sb/d2mfxvo9BMoAAADAl4XCUXnLnLIZRq5LAQAAQBoIlPvB5qlSPNggyzR7HVvucak5TKAMAAAAbC8YidHuAgAAoIARKPeD4a2SLFNWpKnXsaxQBgAAALoKhmNsyAcAAFDACJT7weaplCRZLQ29jvV7XGppjSue6H01MwAAAFAsQpGYfG4CZQAAgEJFoNwPhqdKkmSG+hYoS2KVMgAAALCdYDgqHyuUAQAAChaBcj8Y3irZvVWSmeh1bHl7Xzj6KAMAAABJpmUpFInJSw9lAACAguXIdQGFxFbq05j/Xar6+mCvY1mhDAAAAHQWbo3LssQKZQAAgALGCuUM6QiUmwiUAQAAAEnJdheS5KWHMgAAQMEiUO6nzY/fovCTN/Q6jhXKAAAAQGehSEyS2JQPAACggBEop8Hc9nmvY0qcdpW47GpuiWWhIgAAACD/BcPtgXIZPZQBAAAKFYFyPzl8VbJatskyzV7HlrtdbMoHAAAAtGOFMgAAQOEjUO4nh3+4ZCVkRZp6Hev3uGh5AQAAALRL9VBmUz4AAICCRaDcT3b/MEmS1dLY61gCZQAAAOALwXBMJU67XE57rksBAABAmgiU+8nhHy5JMkNbex3r97jURKA
MAAAASEq2vGB1MgAAQGFz5LqAQuMavpM837xZRqmv17F+t1MtkZgSpim7jeweAAAAxS0YjslL/2QAAICCRsrZT4bdIVuZX4Zh9Dq23OOSpS92swYAAACKWSgSZUM+AACAAkegnIbWl/+ottcf6XWc3+OSJPooAwAADBErVqzQzJkzNX36dN1///1dzr/wwguqra1VbW2tfvCDH6ilpUWStHHjRp199tk69thjdeGFF6aOF5tgOCYfLS8AAAAKGoFyGsytnyq+4a1exxEoAwAADB11dXVasmSJHnjgAT366KN68MEH9cEHH6TONzc3a9GiRVqyZIlWrFihcePGacmSJZKkq666SmeddZZWrlyp8ePH67bbbsvV28ipYCQmn9uV6zIAAAAwAATKaTA8lbJCDb2O6wiU2ZgPAACg8K1evVpTpkxRRUWF3G63ZsyYoZUrV6bOf/zxxxo1apTGjh0rSTryyCP1zDPPKBaL6bXXXtOMGTMkSSeffHKn64pFLJ5QWzTBpnwAAAAFjk350mDzVCke3ibLNGXsYLM9f/vqi+YwgTIAAECh27x5swKBQOp1dXW11qxZk3q96667atOmTXr33Xc1btw4/fWvf9WWLVvU2Ngor9crhyP50TsQCKiurq7fzx82zDvwN5GGQKD3zaj7Ysu2iCRp1AjfoN0znxXDe8wU5i59zF36mLv0MXfpYd7Sx9zlHoFyGgxvlWSZsiJNMjyVPY4rddnlcthoeQEAADAEmKbZaWNmy7I6vfb7/br++uv1k5/8RKZp6rTTTpPT6ewyTlKfNnj+sq1bQzJNK/03kIZAwKf6+uCg3OvTuuR9rLg5aPfMV4M5b8WGuUsfc5c+5i59zF16mLf0MXf9Y7MZGVmUQKCcBpunSpJktTRIOwiUDcOQ3+MiUAYAABgCampq9Prrr6de19fXq7q6OvU6kUiopqZGf/7znyVJa9as0c4776yqqioFg0ElEgnZ7fYu1xWLYDgmSfK5aXkBAABQyOihnAbbiD1UNvOHslWM7HUsgTIAAMDQMHXqVL388stqaGhQJBLRqlWrNG3atNR5wzB07rnnqq6uTpZl6e6779bMmTPldDo1adIkPfnkk5KkRx99tNN1xSIYSX4mJlAGAAAobATKabCV+uTYabwMl7vXsX63S00tsSxUBQAAgEwaMWKEFixYoLlz5+qkk07SrFmzNHHiRF1wwQVau3atbDabFi9erPPPP1/HHnus/H6/zjvvPEnSFVdcoWXLlmnmzJl6/fXXdfHFF+f2zeRAxwplNuUDAAAobLS8SFN07VMyvMPl3O3AHY7ze1z66PPmLFUFAACATKqtrVVtbW2nY0uXLk39fMQRR+iII47oct3o0aN17733Zrq8vBYKx2QYkqeUQBkAAKCQsUI5TdG3n1P8w1d7Hef3uBQMR7O+gQoAAACQT0KRmDylTtls/d+QEAAAAPmDQDlNNk+VzJaGXseVe1yyrOQHaAAAAKBYBcNR+icDAAAMAQTKaTI8VbJCvQfKfo9LktiYDwAAAEUtFInJR/9kAACAgkegnCabt0pWeJss09zhOH/7KowmAmUAAAAUsWA4Jq/blesyAAAAMEAZDZRXrFihmTNnavr06br//vt7HPf888/r61//eiZLGXSGp1KyTFmRph2OY4UyAAAAIAUjMVpeAAAADAF9CpRDoZAk6e2339ajjz6qWKz3fsB1dXVasmSJHnjgAT366KN68MEH9cEHH3QZt2XLFl1//fX9LDv37DVfkeug02TYd/yhuLw9UGaFMgAAAIqVaVkKhWPy0vICAACg4PUaKN9000264oortHHjRp1//vlavny5rrzyyl5vvHr1ak2ZMkUVFRVyu92aMWOGVq5c2WXc5Zdfrosuuiit4nPJXrWTSvabKaPUu8NxZSUOOeyGmsMEygAAAChOkba4TMuSj5YXAAAABc/R24AXXnhB999/vx588EEdf/zxuuyyy3TKKaf0euPNmzc
rEAikXldXV2vNmjWdxtxzzz3aZ5999NWvfjWN0qVhw3Yc5mZKIOCTZZmKfPSmHL5hclXvssPxFb5SRROWAgFflirMX8xB+pi79DBv6WPu0sfcpY+5Sx9zh3wWCie/4cimfAAAAIWv10BZksrKyrR69WqdfvrpkqRotPfVtqZpyjCM1GvLsjq9fv/997Vq1Srdfffd2rRpU3/rliRt3RqSaVppXZuuQMCn+vqgLMtS6M+/kHOfI1V68Jk7vMZX5tDmrS2qrw9mqcr81DF36D/mLj3MW/qYu/Qxd+lj7tLH3PWPzWbkbGFCsQq2B8peeigDAAAUvF5bXlRWVurKK6/UunXrNHXqVN1www2qrq7u9cY1NTWqr69Pva6vr+903cqVK1VfX69TTjlF3/72t7V582adddZZab6N7DMMQ4a3UlZLY69j/W4Xm/IBAACgaAUjyc/CbMoHAABQ+HoNlK+//npVV1frjjvuUFlZmQzD6NMmelOnTtXLL7+shoYGRSIRrVq1StOmTUudnz9/vp566ik99thjuvPOO1VdXa0HHnhgYO8my2yeKpktDb2O83tcaqKHMgAAAIpUaoUyLS8AAAAKXq+B8vDhwzV37lxNmDBBb7/9tvbYYw+Vl5f3euMRI0ZowYIFmjt3rk466STNmjVLEydO1AUXXKC1a9cOSvG5ZniqZIX6FigHW2Iyrey25wAAAADyQSjS3kOZTfkAAAAKXq89lG+66SZ9+umn+sEPfqDzzz9fY8eO1WuvvaZrrrmm15vX1taqtra207GlS5d2GbfTTjvpueee60fZ+cHmrVI8vE2Wacqw9ZzN+z0umZallkiMD9EAAAAoOsFwVC6HTSVOe65LAQAAwAD1ukL5hRde0M9+9jOtWrVKxx9/vO655x69++672agt79kCu8mx+9ekeNsOx5V7kiEyfZQBAABQjELhGP2TAQAAhoheA2VJKisr0+rVqzVlyhRJUjRKMCpJzl0PUNlRF8pwle1wnN9NoAwAAIDiFYzE5C3jm3oAAABDQa+BcmVlpa688kqtW7dOU6dO1Q033KDq6ups1Jb3LMuSGWmW1dayw3H+9hXKbMwHAACAYhQMx+RlhTIAAMCQ0GugfP3116u6ulp33nmnysrKZBiGrr/++mzUlv/aWtRy73zF3vv7Dof5Uy0vYtmoCgAAAMgroUiUlhcAAABDRK+b8g0fPlwnnnii/vnPf+qdd97R7NmzNXz48GzUlv9KPJLdJbOlYYfDPKUO2W0GLS8AAABQlILhmLxlBMoAAABDQa8rlF966SWdcsopeuaZZ/Tss89q9uzZeuaZZ7JRW94zDEOGt0pWL4GyYRjye1wEygAAACg6sbip1mhCPjc9lAEAAIaCXlco33TTTbrvvvs0duxYSdJ//vMfLVy4UEcffXTGiysENk+lzNCOA2UpuTFfMz2UAQAAUGRCkWTbNx8rlAEAAIaEXlcox2KxVJgsSXvuuacSiURGiyokyRXKjb2O83tcamKFMgAAAIpMsH1RBT2UAQAAhoZeA+XS0lKtXbs29Xrt2rUqKyvLaFGFxOYfIcPllmWZOxzn9zhpeQEAAICi07FCmR7KAAAAQ0OvLS8WLlyo7373uxozZowMw9BHH32km266KRu1FYSSA05QyQEn9DrO73EpGI7KsiwZhpGFygAAAIDcC4bbA2V6KAMAAAwJvQbKkyZN0l/+8he9+eabMk1T++23nyorK7NR25BS7nYpnrAUbovLU8rqDAAAABSHVA9lWl4AAAAMCb22vJCkiooKHX744TryyCNVWVmpM888M9N1FQyzqU6h+y5W7L+v73Cc35NckUHbCwAAABSTYDgqQ5KXRRUAAABDQp8C5S977733BruOgmWUemWFt8kKbt3hOAJlAAAAFKNgJCZPmVM2G23fAAAAhoK0AmVsx+WWHC6ZLQ07HNYRKDcRKAMAAKCIhMIxNuQDAAAYQgiUB8gwDNk8VbL6GCizQhkAAAD
FJBiOykv/ZAAAgCGjx035fvazn3V73LIsxWKxjBVUiAxvlczQjgNlb5lTNsNQc5hAGQAAAMUjFIkpUFGW6zIAAAAwSHoMlCsqKnq86Dvf+U4mailYhqdK5oa3dzjGZhjyuZ2sUAYAAEBRCYZj2n2UP9dlAAAAYJD0GChfdNFF2ayjoJUefKZk7/1rfH6PS80trO4GAABAcbAsS6FITD63K9elAAAAYJD0GCij74wST5/G+T0uNuUDAABA0Yi0xZUwLTblAwAAGELYlG8QJBo+U/ivv1Ji66c7HOd3u2h5AQAAgKIRjCS/nedjUz4AAIC80djYOKDrew2U4/H4gB5QFCxLifVrZDZt2uGwco9LzeGoLMvKUmEAAABA7oTCyUDZW0bLCwAAgHxx/PHH6wc/+IFef/31tK7vNVA+8sgjtWTJEm3YsCGtBxQDm7dKkmSFGnY4zu9xKRY31RpNZKMsAAAAIKeCYVYoAwAA5JvnnntOU6dO1S9+8QvV1tbq/vvvVygU6vP1vQbKy5Ytk91u15w5c/Sd73xHzz//PCtsv8zllhwumS07Xi7u9yQ/SNNHGQAAAMUgGEl+7vXRQxkAACBvlJaW6pRTTtGyZct0+eWX6w9/+IMOO+wwXXXVVX1qh9FroDxy5EjNnz9fzz77rE499VRdffXVOuqoo/S73/1O0SjBqCQZhiGbp0pWaOsOx/k9ya/60UcZAAAAxSDV8oIVygAAAHnlxRdf1Pe//30tWLBARx99tP70pz9p5MiR+t73vtfrtY6+PODDDz/Un//8Zz3xxBPab7/9dPLJJ+ull17S//7v/+r2228f8BsYCgxvVe8rlN0EygAAACgewUhMTodNJU57rksBAABAuyOPPFIVFRU666yz9Mtf/lKlpaWSpL322ksPPvhgr9f3GiifeeaZWr9+vU455RQ99NBDqqmpkSQdccQRmjJlygDLHzpKvjZbsu14wXd5+wplWl4AAACgGITCMXnLnDIMI9elAAAAoN2NN96ovfbaSx6PR9FoVFu3btWwYcMkSc8++2yv1/fa8uKss87S3/72Ny1YsCAVJkuSzWbT3/72twGUPrTYq3eXffiuOxzjc7tkGKxQBgAAQHEIhqNsyAcAAJBnNm3apG984xuSpA0bNuj444/Xc8891+frew2UjzvuOC1btkwXXnihLrroIj388MOpcx6PJ42Sh6ZE4wa1/fMhmZHmHsfYbIZ8ZU41hwmUAQAAMPSFIjE25AMAAMgzv/3tb3XPPfdIknbbbTc98sgjuuWWW/p8fa8tL6655hp98MEHOvHEE2VZlh5++GF9+umnWrBgQfpVD0FWcIui/35CjjH7SWX+Hsf5PS5WKAMAAKAoBMMxDa8oy3UZAAAA2I5pmp06UYwcOVKmafb5+l4D5X/84x/6y1/+IqczubLghBNO0AknnECg/CWGt0qSZIYaZB/R8zgCZQAAABSLICuUAQAA8k5VVZX+9Kc/afbs2TIMQ4888oiGDx/e5+t7bXlRVVWlRCKRem0Yhvz+nlfgFiubJxkoWy0NOxzn97jYlA8AAABDXjxhKtIWl5ceygAAAHll8eLFWrZsmSZOnKiJEydq2bJluuKKK/p8fa8rlMeNG6ezzjpLJ598sux2u5588klVVlbqrrvukiR961vfSr/6ocTllhwlMlsadzjM73bRQxkAAABDXigSk5TcmBoAAAD5Y9ddd9Xy5cvV1NQku90ur9fbr+t7DZTb2tq011576a233pIk7bTTTpKk999/P41yhy7DMGTzVMoKbd3huHKPS9GYqdZoXKWuXqcfAAAAKEihcHugTMsLAACAvNLQ0KDHH39cLS0tsixLpmnqk08+0Y033tin63tNNK+77jpJ0oYNGxSPxzVmzJiBVTyEub46Uyrx7HCM35NcodHcEiVQBgAAwJAVbP9Wno+WFwAAAHnl4osvVmlpqT744ANNnTpVq1ev1oEHHtjn63vtofzJJ5/o+OOP10k
nnaSTTz5ZRx99tD788MMBFT1UOcdNk3O3HU/+F4FyLBslAQAAADkRbG954WWFMgAAQF7ZuHGj7rzzTk2bNk1z5szRH//4R3300Ud9vr7XQHnx4sU6//zz9dprr+mNN97QhRdeqKuuumpARQ9VZmirYh+9JstM9DjG395Djo35AAAAMJQF21teeOmhDAAAkFeGDx8uKdlL+f3339eIESMUj8f7fH2vgfLWrVv1jW98I/X6lFNOUWPjjjeeK1bx9WvV+sytssLbehyTWqHMxnwAAAAYwkKpFcq0eQMAAMgnw4YN0+9+9zuNHz9eDz/8sJ577jm1trb2+fpeA+VEIqFt27alXjc0NKRVaDGweSolSVZLz4F7Rw+5ZlYoAwAAYAgLhqPylDpkt/X6nxwAAADIosWLF8vlcmnSpEkaP368br75Zv3whz/s8/W9LheYM2eOTj/9dB133HEyDENPPvmkzjnnnAEVPVQZ3ipJkhlqkH1E92Mcdpu8ZU4CZQAAAAxpoUiMdhcAAAB56Prrr9cvfvELSdLChQu1cOHCfl3fa6B88skna8yYMXrppZdkmqauuOIKTZ06Nb1qhzibJxkoWy07XsXt97gIlAEAADCkBcMx+diQDwAAIO+88847sixLhmGkdX2vgfLs2bP12GOPacqUKWk9oKi43JKjRGaol0DZ7VQTPZQBAAAwhAXDMQUqSnNdBgAAAL6kurpaxx9/vL761a/K4/Gkjl9++eV9ur7XQLmsrEybNm1STU1N+lUWCcMw5Bw7RbaKkTsc5/e49PGmYJaqAgAAALIvFIlqt5G+XJcBAACAL9l///21//77p319r4FyJBLRUUcdpZqaGrnd7tTxFStWpP3Qoax02rd6HUPLCwAAAAxllmUpGI7J66blBQAAQL656KKLBnR9r4HyZZddNqAHFBsrHpUVbpLNH+hxTLnHpdZoQtFYQi6nPYvVAQAAAJnXGk0oYVrylbEpHwAAQL6pra3t9nhfFxD3Gig/+uijuvbaazsdmz9/vg466KA+PaDYRP/9hKL/t0Le85fKsHU/vf723a6bW6IaXlGWzfIAAACAjAu27xfiY4UyAABA3vnJT36S+jkWi+kvf/mLdt555z5f32OgfMUVV6iurk5vvPGGGhq+2GQuHo9r/fr1aZY79BmeKkmWrHCTDO+wbsf4PclAuSlMoAwAAIChJxiJSSJQBgAAyEdfXig8depUnXHGGbrwwgv7dH2PgfLs2bP1n//8R++9955mzJiROm6327XffvulV20RsHmqJElWqEHqJVCmjzIAAACGolA4GSh7aXkBAACQ9xobG7V58+Y+j+8xUJ4wYYImTJigqVOnqqamZlCKKwaGNxkomy0N6qk7cjmBMgAAAIawYEegzAplAACAvPPlHsobN27U6aef3ufre+2h/Pnnn2vhwoVqamqSZVmp431t0lxsbJ5KSZLV0tDjGJ+bQBkAAKAQrVixQrfffrvi8bjOOeccnX322Z3Ov/XWW/rpT3+qWCymkSNH6pe//KX8fr8eeeQR3XjjjRo2LPkNtiOOOEILFizIxVvIilBHy4syAmUAAIB8s30PZcMwVFVVpT322KPP1/caKP/0pz/VySefrH322UeGYaRXZTFxuWX4R+xwiNNhk7vEoeaWWJaKAgAAwEDV1dVpyZIlWr58uVwul8444wxNnjxZY8eOTY255pprNH/+fB1++OH6+c9/rt///vdasGCB1q1bp0WLFmnWrFk5fAfZEwxH5bAbKnX19J09AAAA5Mouu+yi3/72t7ryyiv10Ucf6YYbbtDixYs1fPjwPl1v622Aw+HQt771LU2ePFkHHXRQ6h90zzAMec+4Xq6Jx+1wnN/jUlOYFcoAAACFYvXq1ZoyZYoqKirkdrs1Y8YMrVy5stMY0zTV0tIiSYpEIiotLZUkrV27Vo888ohqa2v1wx/+UE1NTVmvP5uCkZh8bhcLUgAAAPLQokWLtPvuu0uSRo8erYM
OOkiXXHJJn6/vNVDec8899d5776VfIbrl97hoeQEAAFBANm/erEAgkHpdXV2turq6TmMWLVqkyy+/XIceeqhWr16tM844Q5IUCAT0ve99T48//rhGjhypxYsXZ7X2bAuFY/LS7gIAACAvNTY2au7cuZKkkpISzZs3T/X19X2+vteWF+vXr9cpp5yiUaNGqaSkJHWcHso9a/vnQ4r/93V5Tv95j2P8Hpc+2xzKYlUAAAAYCNM0O624tSyr0+vW1lZddtlluvvuuzVx4kTddddd+vGPf6w777xTt956a2rc+eefr2OOOabfzx82zDuwN5CmQMDX72taYwkNKy9L69qhopjf+0Axd+lj7tLH3KWPuUsP85Y+5m7gEomE6urqNGJEsm3vli1bOu2d15teA+WBbBbS26YlTz/9tG6++WaZpqkJEyZo8eLFcrlcaT8vb9jsMpvqZJlxGbbup7jc7dJbrFAGAAAoGDU1NXr99ddTr+vr61VdXZ16/f7776ukpEQTJ06UJJ1++um66aabFAwG9fDDD2vevHmSkkG03d7/3sJbt4Zkmn3/oD8YAgGf6uuD/b6usblVY2rSu3YoSHfewNwNBHOXPuYufcxdepi39DF3/WOzGd0uSpg3b55OOukkHXbYYTIMQ6tXr9aPfvSjvt+3twEHHXSQSktL9dFHH2m//faT0+nsUw/ljk1LHnjgAT366KN68MEH9cEHH6TOh8NhLV68WHfddZf+8pe/qK2tTY888kifC89nhqdSkiUr3HNvPL/HqUhbXLF4InuFAQAAIG1Tp07Vyy+/rIaGBkUiEa1atUrTpk1LnR8zZow2bdqkjz76SJL07LPPasKECXK73frd736nN998U5J03333pbVCuZAEwzH5yobAQhEAAIAhaPbs2brrrru0zz77aPz48frDH/6g2traPl/fa6C8fPlyXXLJJfrd736nYDCo733ve1q2bFmvN+5t0xK3263nnntOw4cPVyQS0datW+X3+/tceD6zeaskSVaooccxfk/yA3ZzSywrNQEAAGBgRowYoQULFmju3Lk66aSTNGvWLE2cOFEXXHCB1q5dq/Lycl133XW6+OKLVVtbq4cffljXXnut7Ha7fv3rX+vKK6/Ucccdp7feeksLFy7M9dvJmHjCVLgtLp+bHsoAAAD5qK6uTn/60580b948HXLIIVqyZMng9lC+99579eCDD2rOnDkaNmyYli9frvPPP1+nnXbaDq/rbtOSNWvWdBrjdDr1wgsv6Ec/+pGqq6t16KGH9rnwfGZ4koGy2dKgnr7MmAqUw1ENKy/NUmUAAAAYiNra2i6rN5YuXZr6+fDDD9fhhx/e5bpJkyYNmW/j9aYlklww4SVQBgAAyEs//vGP9fWvf12SNHr0aB100EG69NJLO32u3ZFeA2WbzSav94teGyNHjuxTz7feNi3pcPjhh+vVV1/Vr371K1155ZW68cYb+1S4lL8bk5i+XfSxJLfCquhh7K6RePIHh72omokX03sdbMxdepi39DF36WPu0sfcpY+5Q74ItgfKPjctLwAAAPJRY2Oj5s6dK0kqKSnRvHnz9Oijj/b5+l4D5YqKCr3zzjupMPjxxx9XeXl5rzfubdOSbdu2ad26dalVybW1tf3eADBfNyaxLEueOTcpWtbz2EQ0+UF7/cYm7RbwDHqd+YjG6elj7tLDvKWPuUsfc5c+5i59zF3/9LQ5CQZHKNy+QrmMFcoAAAD5KJFIqK6uTiNGjJAkbdmyRZbV94y110D50ksv1f/+7//q008/1aGHHqqSkhLddtttvd546tSpuuWWW9TQ0KCysjKtWrVKV199deq8ZVlauHChHn74YY0aNUorV67UAQcc0OfC85lhGDLcOw7dy1M9lKPZKAkAAADIii9WKBMoAwAA5KN58+bppJNO0mGHHSZJevnll/WjH/2oz9f3Gijvscceeuyxx/Txxx8rkUhot912k9PZ+4fD7TcticVimj17dmr
Tkvnz52vChAm6+uqr9Z3vfEeGYWjs2LG66qqr+lx4vmv79xOymreodNq8bs87HXaVldgJlAEAADCkhMLJz7c+VigDAADkpdmzZ2v8+PF65ZVXZLfbtcsuu+iee+7psldIT3oNlCXJbrdrjz320He+8x3dcccdfS6ut01Ljj76aB199NF9vl8hMbfVKfHZ2h2O8btdag4TKAMAAGDoCLa3vPAQKAMAAOStkSNHKhqN6v7771c4HNY3v/nNPl/bp0C5w+bNm/tdXLGyeasUDzfJMuMybN1Ps9/jYoUyAAAAhpRgJCZ3iUMOuy3XpQAAAOBLPvroI/2///f/9Pjjj2v06NFqbW3Vc889J5+v75t89+tTXn+aMxc7w1MpyZIVbupxjN/jUhOBMgAAAIaQYDgqL/2TAQAA8s63v/1tzZkzR06nU/fcc4+eeOIJeTyefoXJUj8D5fnz5/fr5sXM5q2SJJmhhh7HsEIZAAAAQ00oEmNDPgAAgDz09ttva99999Wee+6pMWPGSJIMw+j3fXoNlLds2aJnn31WkvTGG2/onHPO0bvvvtvvBxUbw5MMlK2WngPlcrdLLa1xxRNmtsoCAAAAMioUjslX5sp1GQAAAPiS559/Xt/4xjf0xBNP6NBDD9X8+fPV1tbW7/v0GigvWrRI69ev18svv6yXXnpJJ554on72s5+lVXQxsfmrVXbcD2QftXePY/ye5Aftjo1LAAAAgEIXjMRoeQEAAJCHHA6HZs6cqXvvvVfLly9XdXW12traNH36dP3xj3/s8316DZS3bdumefPm6cUXX9SsWbN08sknKxKJDKj4YmA4XHLsPEG2Mn+PYzoCZdpeAAAAYCiwLEvBcEy+MgJlAACAfDZ27FhdfvnlevHFF3Xeeedp2bJlfb6210A5FospFovppZde0tSpUxWJRBQOhwdUcLGIvfeSou++0OP5jkCZjfkAAAAwFLRGE4onTFYoAwAAFIiysjKdfvrpeuSRR/p8Ta+B8lFHHaWDDz5YlZWVGj9+vE499VTNmjVrQIUWi9gHryj2Tu+BMiuUAQAAMBSEIslWbvRQBgAAGLocvQ2YP3++TjvtNI0YMUKSdMMNN2jcuHEZL2wosHmrFF+/tsfz5e72QDlMoAwAAIDC1xEos0IZAABg6Op1hfKWLVv01ltvyTAM/fKXv9R1112nd999Nxu1FTzDUyUr3CTLjHd7vsRlV4nTzgplAAAADAnB9oUSPgJlAACAIavXQHnRokVav369Xn75Zb300ks68cQT9bOf/SwbtRU8w1slyZIVbupxjN/jJFAGAADAkBAMd7S8IFAGAAAYqnoNlLdt26Z58+bpxRdf1KxZs3TyyScrEolko7aCZ/NUSpLMUEOPY/weF5vyAQAAYEjoCJS99FAGAAAYsnoNlGOxmGKxmF566SVNnTpVkUhE4XA4G7UVPFvVznIdNFs2T0WPY/xuFz2UAQAAMCSEIjHZbYbKSuy5LgUAAAAZ0mugfNRRR+nggw9WZWWlxo8fr1NPPVWzZs3KRm0Fz+apVMl+s2TzBXocU+5x0fICAAAAQ0IwHJXX7ZRhGLkuBQAAABni6G3A/Pnzddppp6mmpkaSdMMNN2jcuHEZL2yoiG98V4bDKXv1Ht2e93tcCoVjSpim7LZe830AAAAgb4UiMflodwEAADCk9Room6apFStW6MUXX1Q8HtchhxyisWPHyuHo9VJIavvHPbKVj1TZ9O93e97vccmSFArHVO4tyW5xAAAAwCAKRmLyudmQDwAAYCjrdUnsjTfeqFdeeUXnnHOOvvWtb+lf//qXfvGLX2SjtiHB8FTJbNnBpnzu5AoONuYDAABAoQuGCZQBAACGul6XGb/00kt6+OGH5XQmPxgeccQROuGEE3TppZdmvLihwOapVLzhsx7P+z3JQJmN+QAAAFDoQuGovGUEygAAAENZryuULctKhcmS5HK5Or3GjhmeKlnhJlmJeLfnyzsCZVYoAwAAoIA
lTFMtrXECZQAAgCGu10B53Lhxuvbaa/Xpp59q/fr1uu666/SVr3wlG7UNCYa3SpIlK7yt2/OpFcotsewVBQAAAAyylkhyAYXPzaZ8AAAAQ1mvgfIVV1yh5uZmnXHGGTrttNPU0NCgn/zkJ9mobUiwV+0kx+4HSbK6PV/qssvpsLFCGQAAAAUt2N7CjR7KAAAAQ1uvPZTvuOMO/fznP89GLUOSvXoPlR39vR7PG4Yhv9vFpnwAAAAoaKFI8ht3PlpeAAAADGm9rlB+/vnns1DG0Ga2BmVGmns87/e41NzSlsWKAAAAgMEVDCcDZS8tLwAAAIa0Xlco77TTTjr33HN1wAEHyOPxpI5/61vfymhhQ0nLHxfKuddhKp16drfnyz0ubWlqzXJVAAAAwOAJdqxQpuUFAADAkNZroFxRUSFJ2rBhQ6ZrGbJs3ipZoYYez/s9Tn30ec8rmAEAAIB8F2rvoeyl5QUAAMCQ1mugfN1116V+jkajcrn4Clt/GZ4qmS07CpRdCoajMk1LNpuRxcoAAACAwREMx1RWYpfD3mtXPQAAABSwHj/tRaNR/fjHP9bTTz+dOvb9739fl1xyieLxeFaKGypsnl5WKLtdsqwvNjIBAAAACk0oEpOvjMUnAAAAQ12PgfLNN9+sUCikAw44IHVs8eLFampq0i233JKV4oYKw1MpK9IsK9F9EF/uLZEkNbdEs1kWAAAAMGiCkZi89E8GAAAY8noMlJ9//nndeOONGjZsWOrYiBEj9Itf/ELPPPNMVoobKmzlI2SUj5DV1tLteX/7B++mMIEyAAAAClMwHJWP/skAAABDXo89lJ1Op0pLS7sc93q99FHuJ+eeU+Xcc2qP5/2e5HyyQhkAAACFKhSJaedqb67LAAAAQIb1uELZZrMpFAp1OR4KheihPMjKCZQBAABQwCzLUjAck8/NwhMAAIChrsdAedasWbr88ssVDodTx8LhsC6//HJNnz49K8UNFVY8qtD9CxRds7Lb82UlDjnsBoEyAAAAClI0ZioWN2l5AQAAUAR6DJTPOecc+Xw+HXLIITrttNM0e/ZsHXLIIfL7/fqf//mfbNZY8AyHS1a0VWZoa/fnDUN+j4tAGQAAAAUp2L4XiJdAGQAAYMjrsYeyzWbT1Vdfre9+97t66623ZLPZNHHiRFVXV2ezviHD5q2UFWro8bzf7WJTPgAAABSkYCQmSbS8AAAAKAI9BsodRo8erdGjR2ejliHN8FTJbNlBoOxxaVuoLYsVAQAAAIMj1B4oe92sUAYAABjqemx5gcFl81TteIUyLS8AAABQoDpaXvgIlAEAAIY8AuUsMbxVsiLNshLxbs+Xe1wKhmMyLSvLlQEAAAADEwq3t7yghzIAAMCQ12vLCwwO14QZck08Toa9+yn3u11KmJbCrXE2MwEAAEBBCUZistsMlZXwnxcAAABDHSuUs8RwlclwlvR43u9JbmDSRNsLAAAAFJhgOCZvmVOGYeS6FAAAAGQYgXKWmOFtCq9covj6Nd2e7wiU6aMMAACAQhMMR9mQDwAAoEgQKGeJ4ShR4tM3ZTZ81u15AmUAAAAUqlAkRv9kAACAIkGgnCWGq0xylskMNXR7vpxAGQAAAAUqFInJ63blugwAAABkAYFyFtm8lbJaGrs95y51yG4z1BwmUAYAAEBhCYZj8tHyAgAAoCgQKGeR4amS2dL9CmWbYcjndrIpHwAAAAqKaVpqoeUFAABA0SBQziKbp0pWDy0vpGQfZVpeAAAAoJCEWmOyJHkJlAEAAIqCI9cFFBPnhBlyjpvW43kCZQAAABSaUDgmSfLRQxkAAKAoEChnkb1q9A7Pl7td2rilJUvVAAAAAAMXbN8DxEsPZQAAgKJAy4ssMlsa1fbawzK3fd7t+Y4VypZlZbkyAAAAID2hSPsKZVpeAAAAFAUC5SyyohFF/7VCiS0fd3ve73EpnrAUaYtntzAAAAAgTcEILS8AAACKCYFyFtk8lZIkM9TY7Xm
/J/khvIk+ygAAACgQwfYeymzKBwAAUBwIlLPIcJVJrjJZLQ3dnu8IlNmYDwAAAIUiFI6p1GWX08F/WgAAABQDPvVlmc1T1WOgXN7+NcHm9lUeAAAAQL4LRqKsTgYAACgiBMpZZnirZLbsuOUFK5QBAABQKELhGP2TAQAAiogj1wUUG+e4w6VYa7fnvGVOGYbU1NKW5aoAAACA9AQjMZV7CJQBAACKRUZXKK9YsUIzZ87U9OnTdf/993c5/8wzz+jEE0/UCSecoO9973tqamrKZDl5wbnbJDm/cmi352w2Qz63ixXKAAAAKBihcFQ+Wl4AAAAUjYwFynV1dVqyZIkeeOABPfroo3rwwQf1wQcfpM6HQiFdeeWVuvPOO/X4449rr7320i233JKpcvKGGWlW7L+vy2pr6fa83+1Scws9lAEAAFAYgpGYvG4CZQAAgGKRsUB59erVmjJliioqKuR2uzVjxgytXLkydT4Wi+mKK67QiBEjJEl77bWXPv/880yVkzfMrZ+q9enfKNHwWbfnyz1ONbFCGQAAIC/19g28t956S6eccopOOOEEfec731Fzc7MkaePGjTr77LN17LHH6sILL1RLS/eLCwpNWyyhaMykhzIAAEARyVigvHnzZgUCgdTr6upq1dXVpV5XVlbqmGOOkSS1trbqzjvv1NFHH52pcvKG4amSJFktDd2e93toeQEAAJCPevsGniRdc801mj9/vh5//HHttttu+v3vfy9Juuqqq3TWWWdp5cqVGj9+vG677bZcvIVBFwonv1nnpeUFAABA0cjYpnymacowjNRry7I6ve4QDAb1P//zPxo3bpy+8Y1v9OsZw4Z5B1xnOgIBX9rXmv5d9LEkt8Kq6OY+NQGf3nivXsOHe7udr0I3kLkrdsxdepi39DF36WPu0sfcpY+5y7ztv4EnKfUNvIsuuig1xjTN1OrjSCSi8vJyxWIxvfbaa7r11lslSSeffLLmzJmjhQsXZv09DLZgJLkQgh7KAAAAxSNjgXJNTY1ef/311Ov6+npVV1d3GrN582add955mjJlii699NJ+P2Pr1pBM0xpwrf0RCPhUXx8c2E1cZQrWbVSsm/s4DSkaN7V+wzaVlWTstycnBmXuihRzlx7mLX3MXfqYu/Qxd+lj7vrHZjPSWpjQ3Tfw1qxZ02nMokWLdO655+raa69VWVmZli1bpsbGRnm9Xjkcyc92gUCg0zf3ClnHCmVaXgAAABSPjCWWU6dO1S233KKGhgaVlZVp1apVuvrqq1PnE4mEvvvd7+q4447T9773vUyVkZdsnipZLY3dnvN7kqs7mluiQy5QBgAAKGS9fQOvtbVVl112me6++25NnDhRd911l3784x/r6quv7vLNs3S+iZaP384zPt0mSdplpwoFArmpL1/xrYH0MXfpY+7Sx9ylj7lLD/OWPuYu9zKWWI4YMUILFizQ3LlzFYvFNHv2bE2cOFEXXHCB5s+fr02bNuntt99WIpHQU089JUkaP368rrnmmkyVlDccux4g2bv/WqDfk1zd0dQS1YgqdzbLAgAAwA709g28999/XyUlJZo4caIk6fTTT9dNN92kqqoqBYNBJRIJ2e32br+51xf5+O28DXXJc7HWKKvkt8O3BtLH3KWPuUsfc5c+5i49zFv6mLv+Sfebeb3J6BLY2tpa1dbWdjq2dOlSSdKECRP07rvvZvLxeavka6f0eM7f/nVBNuYDAADIL719A2/MmDHatGmTPvroI+2+++569tlnNWHCBDmdTk2aNElPPvmkamtr9eijj2ratGk5fCeDJxSJymYYfLMOAACgiPDJLwesRFxWS4MM7zAZNnunc+XtK5SbwwTKAAAA+aS3b+BNmDBB1113nS6++GJZlqVhw4bp2muvlSRdccUVWrRokW6//XaNHDlSv/rVr3L8bgZHMByT1+2UbQhuJg0AAIDuESjnQPw/q9X64h/kOeOXMvyBTue8bqcMsUIZAAAgH+3oG3i
SdPjhh+vwww/vct3o0aN17733Zry+bAuFY/KVdd/KDQAAAEOTLdcFFCPDWyVJMlsaupyz22zyup0EygAAAMh7wXBUXgJlAACAokKgnAOGJxkoW90EylJyY74mAmUAAADkuWAkJp+bQBkAAKCYECjngM1TKUkyQz0Eym4XPZQBAACQ90KRmLztm0oDAACgOBAo54DhKpNcZT2uUC73uGh5AQAAgLxmWpZCEXooAwAAFBsC5RyxD99VsnW/J6Lf41JzSyy7BQEAAAD9EG6Ny7KSm0oDAACgeHSfaCLj3LN+3OM5v8eltlhCbdGESlz2LFYFAAAA9E2wvUUbK5QBAACKCyuU85C/vQ9dE32UAQAAkKeC4eQ36nz0UAYAACgqBMo5En3rWQXvvlBWomtrC78n+aGcPsoAAADIV6FI8nOslxXKAAAARYVAOUcMh0uKRmS1bOtyrpxAGQAAAHku1fKCHsoAAABFhUA5RwxPlSTJbGnoco4VygAAAMh3HSuUCZQBAACKC4FyjhjeSkmS1U2g3PGhnEAZAAAA+SoYjqnEZZfTwSbSAAAAxYRAOUdsHSuUQ10DZYfdJk+pg035AAAAkLeC4Zh89E8GAAAoOgTKOWI4SyWXu9sVylKy7QUrlAEAAJCvgpEoG/IBAAAUIUeuCyhmnjOul1Hi6fZcOYEyAAAA8lgoHJPP7cp1GQAAAMgyVijnkK3UJ8Po/reAFcoAAADIZ6FIjBXKAAAARYhAOYei776gyFM3dXvO73apmR7KAAAAyFPBcCy1mTQAAACKB4FyDlkt2xT/5F+yErEu5/welyJtCcXiiRxUBgAAAPQsGkuoLZYgUAYAAChCBMo5ZPNWSZKslsYu5/yeZD+6JtpeAAAAIM+EIskFEfRQBgAAKD4EyjlkeColSWaoocu5jkC5uaXr6mUAAAAgl4Lh5GdUeigDAAAUHwLlHDJSK5S7BsrlqUCZFcoAAADIL8FI8jMqgTIAAEDxIVDOIZsnGSib3QTKfndHy4u2rNYEAAAA9CYU7mh5QaAMAABQbBy5LqCYGc5SlR33A9kqR3c55/ckP5yzQhkAAAD5JkgPZQAAgKJFoJxjjp0ndHvc6bCrrMRBD2UAAADknWA4JsOQ3KX85wQAAECxoeVFjsX++7qibz7Z7Tm/x6WmMCuUAQAAkF9CkZi8ZU7ZDCPXpQAAACDLCJRzLLF+naJv/rXbc+VuJy0vAAAAkHeC4SjtLgAAAIoUgXKOGd5KWa1BWfGuwbHfW0KgDAAAgLwTCidXKAMAAKD4ECjnmM1TJUmywtu6nCt3uwiUAQAAkHdCkZh8BMoAAABFiUA5x4z2QNkMNXQ55/c4FW6LKxY3s10WAAAA0KNkywsCZQAAgGJEoJxjNm/7CuWW7gLlZF+6IBvzAQAAIE+YlqVQJC4vgTIAAEBRIlDOMcNbJdfXTpFt2C5dznUEyk20vQAAAECeCLfGZVqWfGVsygcAAFCMHLkuoNgZjhKV7F/b7bmOQJk+ygAAAMgXoUhMklihDAAAUKRYoZwHEps/UnzD212Ol7sJlAEAAJBfOtqx0UMZAACgOBEo54G2Nx5V26sPdjmeWqFMD2UAAADkiVA4uUKZlhcAAADFiUA5D9g8VbJCXTflczntKnXZ6aEMAACAvBHsaHlRxgplAACAYkSgnAcMb5Ws1qCseNfg2O9x0fICAAAAeaOj5QU9lAEAAIoTgXIesHkqJUlWeFuXcwTKAAAAyCehSEwup00lTnuuSwEAAEAOECjnAcNTJUkyu2l7Ue52qbm9Tx0AAACQa8FwjP7JAAAARYxAOQ/Yyqvl2G2SDGdpl3OsUAYAAEA+CUVitLsAAAAoYo5cFwDJ5guo7JiLuj3n97gUisQUT5hy2Mn/AQAAkFvBcFQ+NuQDAAAoWiSUecJqDclsaexy3O9Jfp0wSNsLAAAA5IFgOCYfK5QBAACKFoFyngg/8XO1/f2eLsf97mSgTNsLAAAA5INQJCYvPZQ
BAACKFoFynjA8Vd1vyte+Qrk5TKAMAACA3IrFTbVGE6xQBgAAKGIEynnC5qmS1dI1UPZ7kh/WWaEMAACAXAtFkm3Y2JQPAACgeBEo5wnDWyWrNSgr3jk47uihTKAMAACAXAu2f2vOR8sLAACAokWgnCdsnkpJkvWljflKXQ65nDY1ESgDAAAgx4LtK5RpeQEAAFC8CJTzhOELyPCPkBVr7XLO73bRQxkAAAA5Fwq3t7woI1AGAAAoVo5cF4Akx6hx8p5xfbfnyj0uWl4AAAAg51ItL1ihDAAAULRYoVwA/ATKAAAAyAOhSEyGJE8pgTIAAECxIlDOIy0P/0Stq+/vcpxAGQAAAPkgGI7JU+aUzWbkuhQAAADkCIFyPrEkK7ily2G/26VgJCbTtHJQFAAAAJAUjMRodwEAAFDkCJTziOGtkhlq6HLc73HJsr7YVRsAAADIhVA4Kh8b8gEAABQ1AuU8YvNUymrpGiiXe1ySpKZQW7ZLAgAAAFKCkZi8bleuywAAAEAOZTRQXrFihWbOnKnp06fr/vu79gbu8KMf/UjLly/PZCkFwfBUyWoNyop37pfsbw+Um8P0UQYAAEDuhMIxeVmhDAAAUNQyFijX1dVpyZIleuCBB/Too4/qwQcf1AcffNBlzHe/+1099dRTmSqjoNi8VZIkK7yt0/FUoMzGfAAAAMgRy7IUoocyAABA0XNk6sarV6/WlClTVFFRIUmaMWOGVq5cqYsuuig1ZsWKFTrqqKNSY4qdY/evybvbJBnO0k7H/e6OQJkeygAAAMiNSFtcCdOihzIAAECRy1igvHnzZgUCgdTr6upqrVmzptOY888/X5L0xhtvpPWMYcO86Rc4AIGAL0N37v6+lmXJ6bApZmXy2dlR6PXnEnOXHuYtfcxd+pi79DF36WPukGnBcHJxg48eygAAAEUtY4GyaZoyDCP12rKsTq8Hw9atIZmmNaj37E0g4FN9fTAj97bMhFqf/o0cux4g516HdTrnd7u0qT6UsWdnQybnbqhj7tLDvKWPuUsfc5c+5i59zF3/2GxGzhYmFLJgJBkoe2l5AQAAUNQy1kO5pqZG9fX1qdf19fWqrq7O1OOGBMNmV3zT+0rU/7fLOb/HxaZ8AAAAyJlg+2dReigDAAAUt4wFylOnTtXLL7+shoYGRSIRrVq1StOmTcvU44YMm6dKZqihy/Fyj4tN+QAAAJAzofaWF156KAMAABS1jAXKI0aM0IIFCzR37lyddNJJmjVrliZOnKgLLrhAa9euzdRjC57hqZTV0jVQ9nucBMoAAADImVB7ywtfGT2UAQAAilnGeihLUm1trWprazsdW7p0aZdxP//5zzNZRkGxeasU76HlRTAck2lZsg1yL2oAAACgN8FwTC6HTSUue65LAQAAQA5lbIUy0mN4qmS1BmXFO69G9rtdMi0rtTIEAAAAyKZgJMqGfAAAAMjsCmX0n3PswbKPGifZOmf9fk/yq4XNLVH53XzNEAAAIBdWrFih22+/XfF4XOecc47OPvvs1Ll33nlHixYtSr1uaGhQeXm5nnjiCT3yyCO68cYbNWzYMEnSEUccoQULFmS9/oEIhmO0uwAAAACBcr6x+QOy+QNdjpdvFyir62kAAABkWF1dnZYsWaLly5fL5XLpjDPO0OTJkzV27FhJ0t57763HHntMkhSJRHTqqafqyiuvlCStW7dOixYt0qxZs3JV/oCFIjFWKAMAAICWF/nGikbU9vojSmz6T6fjNVVuOeyG/r7m8xxVBgAAUNxWr16tKVOmqKKiQm63WzNmzNDKlSu7HXvHHXfoa1/7miZNmiRJWrt2rR555BHV1tbqhz/8oZqamrJZ+qAIhqPyESgDAAAUPVYo5xubTdH/e0yyO2Sv2TN1uNxboplTxujxf3ysQyaO1L67VuWwSAAAgOKzefNmBQJffFWsurpaa9as6TIuGAxq2bJlWrFiRepYIBDQueeeqwMOOEC/+tWvtHjxYt14443
9ev6wYd70ix+AQMAnSWppjau6ypN6jR1jntLH3KWPuUsfc5c+5i49zFv6mLvcI1DOM4ajRCrxyGpp7HLu+IPH6JW363TfU+9p8XkHyelgh20AAIBsMU1ThmGkXluW1el1h8cff1xHH310ql+yJN16662pn88//3wdc8wx/X7+1q0hmabV7+sGIhDwqb4+qHjCVLg1Lrss1dcHs1pDIeqYN/Qfc5c+5i59zF36mLv0MG/pY+76x2YzMrIogZYXecjmrZIZauhy3Omw65vT91JdY0RPvvJpDioDAAAoXjU1Naqvr0+9rq+vV3V1dZdxzzzzjGbOnJl6HQwGdffdd6deW5Ylu72wFgYEwzFJko/NoQEAAIoegXIeMjxVslq6BsqStO9uVZq8zwj95eWPVdcQznJlAAAAxWvq1Kl6+eWX1dDQoEgkolWrVmnatGmdxliWpbfeekv7779/6pjb7dbvfvc7vfnmm5Kk++67L60VyrkUiiQDZW8ZPZQBAACKHYFyHrJ5qrptedHhjK+PldNh172r3pNlZfdrjwAAAMVqxIgRWrBggebOnauTTjpJs2bN0sSJE3XBBRdo7dq1kqSGhgY5nU6VlJSkrrPb7fr1r3+tK6+8Uscdd5zeeustLVy4MFdvIy3BcFSS2JQPAAAA9FDOR47dvyZb5ege+/KVe0t0yuG7675V7+vVd+o0ZZ+aHFQJAABQfGpra1VbW9vp2NKlS1M/Dxs2TP/4xz+6XDdp0iQ98sgjGa8vU1IrlGl5AQAAUPRYoZyHHKP3kWv80d2GyR2O2G+0dhvp05+e/UDh1lgWqwMAAECxSfVQpuUFAABA0SNQzkNWNKLYx2/IDG3tcYzNZmjujHEKhqN6+MWPslgdAAAAik0wHJUhyVPGFxwBAACKHYFyHrIiTWpddYsSG9/d4bgxNT4ddeBOev7/Nuijjc1Zqg4AAADFJhSJyV3qkN3Gfz4AAAAUOz4R5iHDUylJMlsaeh37jcN2V7nXpXueelcJ08x0aQAAAChCwXBMPvonAwAAQATKeclwlMgo8cpqaex1bFmJQ2cd/RV9WhfSc29syEJ1AAAAKDahSExeN/2TAQAAQKCctwxv5Q57KG/vwL0CGr97lZa/9JEag20ZrgwAAADFJhiOsiEfAAAAJBEo5y3DU9WnFcqSZBiG5hzzFZmmpT8+836GKwMAAECxCUZi8rFCGQAAACJQzluOncbLPnJcn8dXV7o1a+quev29eq35sG8rmwEAAIDeWJalUDgmbxk9lAEAAECgnLdc449R6dSz+nXNsQftopHD3Lpv1XuKxhIZqgwAAADFJNKWUMK0WKEMAAAASQTKecsyEzKDW2TF+94T2emw6ZvT99KWplY98fLHmSsOAAAARSMUiUqSvPRQBgAAgAiU81bi8/fU8scfKlH3Yb+uGzemUlPH1+ivr3yqjVtaMlQdAAAAikUwHJMk+dy0vAAAAACBct6yeaskqc8b823vtCPHqtRl171PvSfLsga7NAAAABSRYKQjUGaFMgAAAAiU85bhqZQkmaH+b7Dn97g0+4g99N76bVq9btNglwYAAIAiEgwnW174aHkBAAAAESjnLcNRIqPEm9YKZUk67KujtMdovx587gOF2leVAAAAAP3V8VnSywplAAAAiEA5rxneSpktDWldazMMzZ0xTuHWuB56vn99mAEAAIAOoXBMDrtNJU57rksBAABAHiBQzmO2qp1lOErSvn7naq+mf21nvfjmRn3wWdMgVgYAAIBiEQzH5HM7ZRhGrksBAABAHiBQzmNlR35bZUd/b0D3OOHQXVXlL9E9T72reMIcpMoAAABQLEKRGP2TAQAAkEKgPMSVuhw6++iv6LP6Fj3z+me5LgcAAAAFJhiOykf/ZAAAALQjUM5j8Y//peDd35PZtGlA99n/KwHtN3a4Hv37R9ra1DpI1QEAAKAYBCMxed2uXJcBAACAPEGgnM9cpVI0LLOlccC3OuuYPSVJDzz
z/oDvBQAAgOIRCtPyAgAAAF8gUM5jNk+VJMkKNQz4XsPLy3TiobvpX//Zon+9Xz/g+wEAAGDoiydMhdvi8tLyAgAAAO0IlPOY4amUJJktAw+UJemYSTtrdMCj+595X63R+KDcEwAAAENXsCUqSaxQBgAAQAqBch4zHC4Zpb5BWaEsSQ67TXNn7KWG5jY9/o+PB+WeAAAAGLqaOgJleigDAACgHYFynjM8lYO2QlmS9typQtO+OlKr/rlen20ODdp9AQAAMPQ0t7RJkrysUAYAAEA7AuU85z7+Ryqb/r+Des/ZR4yVu9She556T6ZlDeq9AQAAMHQ0hTpWKBMoAwAAIIlAOc8ZpV4ZtsH9bfKWOXXakWP1wYYm/X3N54N6bwAAAAwdze0tL7y0vAAAAEA7AuU8F//kXwo/cb2seNug3veQCTX6ys4V+vPfPlBzODqo9wYAAMDQ0BEoe0odOa4EAAAA+YJAOc9ZbWElNr4jK9Q4qPc1DEPfnLGXWqMJ/flvHwzqvQEAADA0NIfa5Cl1yGHnPxsAAACQxFKDPGd4qyRJZkuDbBU1g3rv0cM9OnbyLvrLy5/o0AkjtdculYN6fwAAABS25pYoG/IBAICClEjE1dhYr3i8OL6Z73C4VFkZkN2e+biXQDnP2TzJQNlqacjI/WdN3VWvvl2ne556T1edexCrTwAAAJDS1NImH/2TAQBAAWpsrFdpqVseT40Mw8h1ORllWZZaWprV2Fiv4cNHZvx5pId5zvAkVw2bocwEyiVOu+ZM/4o+3xrWU//8NCPPAAAAQGFihTIAAChU8XhUHo9/yIfJUrK1rcfjz9pqbALlPGc4XDJKfRlboSxJE/cYrgP3Cujxf3yszdsiGXsOAAAACktTKCqfm0AZAAAUpmIIkztk870SKBeA0mMukuurMzP6jDOP2lM2m6H7V70vy7Iy+iwAAADkP8uykiuUCZQBAAAGJBQK6ZJLftjn8e+++7Z+/vOrM1jRwBAoFwDHyL1k81dn9BlV/lJ947DdtfajrXrjvfqMPgsAAAD5rzWaUDxhyldGD2UAAICBCAab9Z//vNfn8ePG7aNFi36SwYoGhk35CkB8w9tKbHhbJQfNzuhzjjpwtFav/VwPPPO+9t2tSmUl/PEAAAAoVsFITJJoeQEAADBAv/71L7VlS70uueSH+uST/6q8vEIlJSW65ppf6LrrrlZ9/WZt2VKvSZMO0qJFP9G//vWG/vCHO/Wb39ypiy76tvbZZ1+9+ea/tW1boy6+eKEOPviQnL4fEsMCkNj8kaL/fkKuA2plOEoy9hy7zaZvHruXrr3nDT360n915tF7ZuxZAAAAyG+hcDJQZlM+AABQ6P6x9nP9fc3nGbn3oRNH6pAJI3c45uKLF+r73/+O5s///3TqqSfoz3++RSNHjtLTT6/Unnt+RT/72fWKxWKaM+dUvffeu12uj8XiuuOOu/T3v7+opUtvJ1BG72yeSkmSFWqUUVGT0WftMapcR+w/Ws+8sV5Tx9doTI0vo88DAABAfgqGk7uE+9y0vAAAABgslZVVGjlylCTpmGOO1dtvr9OyZQ/o44//q6amJkUi4S7XTJ58sCRp9933UDDYnNV6u0OgXAAMb5UkyWxpkC3DgbIknXL47nrjvc2656l3ddk3J8lmK54dMQEAAJAUam95waZ8AACg0B0yofdVxNlSUvJF94GHHvqTnn/+OZ1wwjc0e/ZB+u9/P5RlWV2ucbmSf8FvGEa357ONTfkKgM2TDJStloasPM9d6tQZR+2p/34e1Av/3pCVZwIAACC/BNtbXvhoeQEAADAgdrtdiUSiy/HXXntVJ5xwsqZPP07RaFT/+c/7Mk0zBxX2DyuUC4DR3vLCDGUnUJakyfuM0EtrPtdDL3ykA74SULk3c72bAQAAkH+CkagcdptKXfZclwIAAFDQqqqGacSIGl177VWdjp922lm64YbrdN99d8nj8Wr8+In6/PONGj16pxxV2jcEygX
AcLhUMvl02Wuyt0meYRj65oy99NPfv6oHn/tA3z5h36w9GwAAALkXCsfk97hkGLQ/AwAAGAiHw6Hf/vYPXY4feODX9Mc/Lu/2mgMOmCRJ+s1v7kwdGzlylB56aEVmiuwHWl4UCNdXj5N9xNisPrOmyq2ZU8bolbfr9NbH2VsdDQAAgNwLhmMq97IhHwAAADojUC4Qia3rFf/4X1l/7vEHj1F1ZZnue+o9xeJde70AAABgaApFkiuUAQAAgO3R8qJAxN75m2Ifvirfrrdm9blOh13fnL6Xbnzw3/rdE+9o15E+Oew2OR02Oe02Odr/cTqM1HGHvf2cwyaH3ZCz/XhbLCHTsmTja5MAAAB5LxiOalTAm+syAAAAkGcyGiivWLFCt99+u+LxuM455xydffbZnc6/8847uuyyy9TS0qJJkybpqquuksNBxt0dw1MltbXIirfJcGR3g7x9d6vSkfuP1t/+tUGvvbt5wPez24z2INr4IoB2bBdOb3fcsV1w3XlscozTYZfTbqTGffl+zh0cd9gNegICAAD0gBXKAAAA6E7G0tu6ujotWbJEy5cvl8vl0hlnnKHJkydr7Ngv+gAvXLhQP/vZz7Tffvvp0ksv1bJly3TWWWdlqqSCZvNWSZKsUIOMipFZf/43Z+ylM4/eU/GEqXjCUixuKpYwFY+biie++Dn5q9XpWDxhKhY35Sp1qqm5NfW6068Jq9PY1mhCsURM8fbjsS9dkzCtQXlfHaurew6dk7/abYYsSzKt5HNNy5JlSdaXf9WXjrW/Ni1J1hc/W0oONrtc13FtckzHfRx2myxLstkkm2HIZmv/x2j/Z/vjqV8lo/21vf1Xo/24zWbInnr9xfgu926/r2G0388wZBjdHzPar02dl1Kvt7/OZnQ91v39t7u3DLX/X/uvyeOSvjjf/nPy1+SRcNxSY2OL1P56+/Ptt9ruvu1jun3Gl+6/3fjU9V8at90Tuta63X2M9pts/3cbHeNMy1LCtGSaVuefze6Pb/+r1cNx0/rSMeuLe27/2u0uUWtrVHZb8s+/3Z78c+Sw2VI/2zv9bMhut3X62fHl41+6LpPfVtj+3z9r+3/3+vC64/dy+3+HOv58dvx70fFzx59TABhqEqaplta4/N7sLmQAAABA/stYoLx69WpNmTJFFRUVkqQZM2Zo5cqVuuiiiyRJGzZsUGtrq/bbbz9J0sknn6ybb76ZQLkHhqdSktT64l0qPfLbsvmGK/r235RYv6bLWOfeh8uxy36Kb/qPYm8+2eW8bcSeKtlvpqxoRK3P/67rw5ylKjvyguTzVt8vK9R5Qz67JPfUs2SrGKbouy8osX5t11vsNU2OXSYqUfeBomtWSpJKWp1qa4sl7zFiD7kmHpes4cWuu1wazlKVHn5esoZX/tSlBkuW7AeeqkRZpWLvvSRr41upcNZsD9OCI7+mYOVecjT8VxXrX0oGaduN2VY6Sh+VT5YVi2jC5r8kz5lfhLxtMaeesb6ucFtch8X+Lr9C0vbBpqS/Ow9Ri92vfeNva+fEp50DR0nvu/bVBtduqk58rvGt/ydju67lhiFtcY7Su+5JclptmhJcpe1jKcMwFDNc+r+q41RS4tSem55SWaJZ2i6AlqRXyw5T0PDpK63rNDr6iZKHrfZfpXXGOH2oXVRjfq5J1ppU0NZxn8/MgP4eHy+n2aaTXP+QUkFcclCr6dSfwlMlSSeWva4KW0uX36/HIweq0fRqsus/GufcqPb8XGb7+VejY/VubLTG2Ot1ROnbXa7/JBHQ8637qERRneF5ucv5Nqv/NUjSh9udz1UN28tGDY9FJmmb6Wmv4fNuathjuxrekaHkv9P2jhriw/V8W0cNr3S61pTUbDn0YKcawt3UcGCfatjV0V5DKrBN/rnfqGq9ak2U02rTLNtLqT/LHX+u2+TU8rZDZcnScc5/qtxoSY3p+LPd1xo65uHLdjQPktT6pXmotCfnYfuAeWXsa2qWT5Mc72usbUO
Xv5BYY4zTJ8bOGq06TdLaLn9B8bkxQv+y7yeXFdXRib91qSEql551HilJOiz+D3mtUJcxLzmmKmT4tE/iHY0x13c5/7Z9nD6xjVGNuUn7J978Yp7b5/FzY4T+z/ZVOa2ojjGf7zidGhWVU0/qCMmy9HW9LJ9a1PH/XDp+2542J6vJ8mp/4z3tbny23dOTz3kj/hV9YO6knYw6HeJ8W7Ks1F+0SNIGVetVa4JKFFOt7cXU3/p0zGdUTq2yHSkZ0pHm6vYatp9LQ/9wTlWLza99Em9rl8Sn251PjnjftY82OHdVIP65JkT/td17TNpsH6E1rgPktNo0rfXZ7d9CsgbDpeddX5ckTY3+XZ5ufi/+3v57sXf8be3S6ffCav+92LvT78V2UyBJqrON0JvO/eWyojoy+twXJ9rHmPYS/c11lAxDmtL6krxmqNP1kvTPsmkK233as22ddop9ou3/JsuQ9J+SfbXRtasCsY3ap7Xrvg31jhq9XXagnGabprY80+V8zHBqtXe6JGlSywvymF3n4XXPNIXtfo1tXatR0U8UclRqp2PO1i4jfF3GIrdCkbgksUIZAAAAXWQsUN68ebMCgUDqdXV1tdasWdPj+UAgoLq6un49Y9iw3PR0CwSy/x89pm9fbVq7r8zWsKoqy+Ss9KnREVdL67YuY31lNnkDPoWDNjVEGrucLzNaNSzgUyIifR6u73LeVupNvcdN0W2KtXQdU1XukrPKp8b3WhUKdf1985WY7TUY2hrcJEmKBr/4b9uS4SM0POBTImJoY3PXsMdW4vmihnC9Yk2bOp03JI0YXiZnVZUaP44qFN7U5R677OKWb/xuCn+4TVvXb9UXy08l2aWxY8bo2On7KxEJaeP/+1M38+DR8fOOSNbw4CuKNbZ2GXPUaZPkrBqlxpfqFXrr3S7npxw6Rr7xhyr84b+09emXupyfuLtbp00/uL2G5d3WcPQ5B7XX8FfFGrv+x/m00/Zrr2GTQm91DcYOOXQP+cZPa6/hn13OT919mL43fWZ7DV1DK1tpic6ce4Isy1Ldn99SvLFJqcCpYx5qD5VRXqPwq2FF3/9E0heBtCRN3n+sjLEHK/7pGpmv/Dt1ndU+8Ksjy3X8pMNktoXkWvV8pzBKkiynTeOOnCLLkryr/0+2UNffi10OGq+ENyD3e1tVtuGj7e6fPD9m7CiFRx2okvp3VPnO/6VOdDxnj2Fl+upe+8sWa9Gofz7X6d6WJNMhVU36qmRZGr3mNTkjEX15kPfQPRQtDSjw6SaVb45sd3XSsJ2Ga8rwfeTb9r52+vj17YLSpNHlTu00Zpxs8Yj2fjsZWn2R9xgyHZaqvvZV2W2GRv77dbkiEX15dfMeU8dL/hqVvtso5ycfdVo9axjSpP12l3OvQ6UNaxX9x787rciWpMN2Ha7502fJbA1p070vtv9ebBcQuuyaVntkcoXz0/+Wmlu7/CXFzl/bR9GygDzv18u98UN1BMEdb7di1DAdULmXvNsM7bb+tS/+EqT9HmVlVYoMD8hlRrTrhqYv3lz7W4nbSnXUV3eRYUj7fvIPuaNd/72wHbiL2soCGr1xo4ZtbZHxpXsEdgnosBHj5Wt8XzUfvN4pfJSk3apKtfeeE2W0tWjXfz37xe9i+w9xe6nsh42TZVn6yvuvqrTtixo6/lwfunO1Qs4q7Vb/oUY1N213eXIyWsptsrsrNCKyVaO3NXT582C5/Brud8tl2lTT0Pkv1SQpapSoqrxMklS9rVn+xLYuY4Z5HXI6yhRoaVNN69Yu5zeXJdRcWqphbZZqQl86b0gq8avG75XLatXorV1riNlK9ZWaShmSxtSH5Y01fnFt+1xOrClX2DVcX2l8X6ODXX8/rWFlGl2xk4aHWrVHXVOXufa6hyk2fITs8bB22bit0zmrvYYx1X5ZljRqS0jeeNf/3St322TYXKoIRxSIb0ld3PEslxlRNGHKnmhTVaLr/+ZFbG6p/S8+hptbvpzTqk0lKil
xyjCkYbEm+c0earCXaFi4VTVtvf1ebOn0/78kKWbzqKzUKWcioUDbli9OtL+PqFUilRiyJJUnGlVubev8/0glxWNtCsfL5Iw1qyLetYWVae2ibbGoyhMtKo93/d/2bYkSNSaiKrHaVB7rer5NJWoMRSVJ7tjWZA1fEgqFtc1WIiPepHKzTrF4Qq5SV04+W2HHPKUO7Td2uL665/BclwIAAIA8Y1jWl/8TdnDcfvvtamtr08UXXyxJWrZsmdatW6fFixdLkt544w3deOONeuCBByRJH3/8sb773e9q5cqVfX7G1q0hmYPU+qCvAgGf6uuDWX3mUMHcpY+5Sw/zlj7mLn3MXfqYu/Qxd/1jsxk5W5gwEHz2LRzMW/qYu/Qxd+lj7tLH3KWHeUtff+Zu06ZPVFMzJsMVZc/vf3+HJOm8877T45gvv+dMfe619T4kPTU1Naqv/2KVT319vaqrq3s8v2XLlk7nAQAAAAAAAAD5JWMtL6ZOnapbbrlFDQ0NKisr06pVq3T11Venzo8ePVolJSV64403dOCBB+qxxx7TtGnTMlUOAAAAAAAAAOTEvfferb/97WklEqYmT56ieDyuQGCEzjxzjiTpsssWavr0mdp55521ZMkvFYlE1NjYoG9+c55OOml2jqvvLGOB8ogRI7RgwQLNnTtXsVhMs2fP1sSJE3XBBRdo/vz5mjBhgm644QZdfvnlCoVC2nfffTV37txMlQMAAAAAAACgSIVXXNftcXftJZKk1tX3y9z6aZfzJQefJfvwMYq995Ji7/+9x+t35JVXVuu9997R0qX3yDAMXX31TzVmzK565pmndOaZcxQOt2jdurW68sprddttN+ucc87TpEkHacOGzzRv3lnFEyhLUm1trWprazsdW7p0aerncePG6aGHHspkCQAAAAAAAACQM6+//k+9/fY6nXfeNyVJbW2tGjGiRtFomz77bL3Wrn1ThxxymJxOpy666GK9+urLuvfeu/Thhx8oEgnnuPquMhooAwAAAAAAAECu9baSuHTq2Ts879zrMDn3OiytZ5tmQqeddqbOOCPZ3iIYDMput8vj8erZZ1dp3bo1mjNnniTppz9dJJ/Pr0MOOUxHHTVdzzzzVFrPzKSMbcoHAAAAAAAAAMXugAO+pqeeelLhcFjxeFyXXPIDPf/8s5o+/Vg999zT+uyz9Zo4cT9J0muv/VPnn/9dHXbYEXrlldWSpEQikcPqu2KFMgAAAAAAAABkyKGHTtMHH7yvb397nkwzocmTp+q442bJMAyVl1do330nyDAMSdK5516gCy88XyUlLu2xx54aOXKUPv98Y47fQWcEygAAAAAAAACQQfPmna95887vcvzmm3/b6fUZZ8xJtcbY3nnnfSdjtfUXLS8AAAAAAAAAAH1CoAwAAAAAAAAA6BMCZQAAAAAAAABAnxAoAwAAAAAAABhyLMvKdQlZk833yqZ8AAAAQB+tWLFCt99+u+LxuM455xydffbZqXPvvPOOFi1alHrd0NCg8vJyPfHEE9q4caMWLlyorVu3arfddtMNN9wgj8eTi7cAAABQFBwOl1pamuXx+GUYRq7LySjLstTS0iyHw5WV5xEoAwAAAH1QV1enJUuWaPny5XK5XDrjjDM0efJkjR07VpK0995767HHHpMkRSIRnXrqqbryyislSVdddZXOOussHX/88br11lt12223aeHChbl6KwAAAENeZWVAjY31CoW25bqUrHA4XKqsDGTnWVl5CgAAAFDgVq9erSlTpqiiokKSNGPGDK1cuVIXXXRRl7F33HGHvva1r2nSpEmKxWJ67bXXdOutt0qSTj75ZM2ZM4dAGQAAIIPsdoeGDx+Z6zKGJAJlAAAAoA82b96sQOCLVR/V1dVas2ZNl3HBYFDLli3TihUrJEmNjY3yer1yOJIfvQOBgOrq6vr9/GHDvGlWPjCBgC8nzy10zFv6mLv0MXfpY+7Sx9ylh3lLH3OXewTKAAD8/+3da2w
UZRvG8Wt7oiAaXG1BoZFARBoNSmjCWagIlHbXAjVSJBQpAiIxAgpaMSmWimQlEYVqjAcgoYFWYoU2HlASkQBBioqAcoiAAtVSKMIWS1vY5/3A677vdhe6TIXd2v/v28wz0z5z5+7kyrOzUwAIgsfj8Xn/njEm4Pv4NmzYoIcffli33XbbFY+z8h6/06dr5PHc2H8sExd3s6qq3Df0d/4bUDfrqJ111M46amcdtbOGullH7a5NRITtujyU0KIXlCMiQvNC7VD93n8DamcdtbOGullH7ayjdtZRO+uoXfCs1qpTp04qLy/3bldVVSk+Pt7vuK+++krTp0/3btvtdrndbl26dEmRkZFXPO96zbu56C1rqJt11M46amcdtbOO2llD3ayjdsG7XrWyGWNu7GMOAAAAQAtUWVmp8ePHa926dWrbtq0yMzO1cOFC9erVy3uMMUZ9+/bVli1b1KZNG+/+adOmyel0yul06p133tHJkyeVm5sbissAAAAAmiUi1BMAAAAAWoKOHTtq9uzZysrK0ujRo+VwONSrVy9NnTpVe/bskSRVV1crOjraZzFZknJzc1VcXKzU1FSVl5dr1qxZIbgCAAAAoPl4QhkAAAAAAAAAEBSeUAYAAAAAAAAABIUFZQAAAAAAAABAUFhQBgAAAAAAAAAEhQVlAAAAAAAAAEBQWFAGAAAAAAAAAASFBWUAAAAAAAAAQFBYUAYAAAAAAAAABIUF5SsoLS1VamqqRowYocLCQr/xn3/+WWPHjtXIkSM1f/58Xbx4MQSzDD/Lly9XWlqa0tLS5HK5Ao4nJycrPT1d6enpAWvbWk2cOFFpaWne2uzevdtnnJ4L7KOPPvLWLD09XX369FFeXp7PMfSdv5qaGjkcDh0/flyStG3bNjmdTo0YMUJvvPFGwHMqKio0YcIEpaSkaMaMGTp//vyNnHJYaFy3oqIiORwOOZ1O5eTkqL6+3u+ckpISDRo0yNt/V6rvv13j2uXk5GjEiBHeunz55Zd+59Bzl/1/7TZv3uxzz+vXr5+mT5/udw59h2tB7rWO7Gsd2dcasu+1I/daR/a1juxrHdm3hTDw88cff5jk5GRz5swZc/78eeN0Os2hQ4d8jklLSzPff/+9McaYnJwcU1hYGIKZhpetW7eacePGmbq6OlNfX2+ysrLMxo0bfY6ZPn26+e6770I0w/Dl8XjMoEGDTENDwxWPoeeadvDgQTN8+HBz+vRpn/30na8ffvjBOBwOc++995pjx46Z2tpaM2TIEPPbb7+ZhoYGk52dbb7++mu/86ZNm2bKysqMMcYsX77cuFyuGz31kGpct8OHD5vhw4cbt9ttPB6PmTdvnlmxYoXfeXl5eaa0tPTGTziMNK6dMcY4HA5TWVl51fNae88ZE7h2fzt58qQZNmyYOXLkiN959B2CRe61juxrHdn3n0H2bRq51zqyr3VkX+vIvi0HTygHsG3bNvXr108dOnRQu3btNHLkSH3++efe8RMnTujChQt64IEHJEljx471GW+t4uLi9OKLLyomJkbR0dHq3r27KioqfI7Zu3ev3n33XTmdTuXl5amuri5Esw0vhw8fliRlZ2frkUce0erVq33G6bngLFiwQLNnz5bdbvfZT9/5Ki4uVm5uruLj4yVJP/74o+666y4lJCQoKipKTqfTr78aGhq0c+dOjRw5UlLr7MHGdYuJiVFubq7at28vm82mHj16+N3zJGnPnj0qKSmR0+nU888/r7Nnz97oqYdc49rV1taqoqJCL730kpxOp9566y15PB6fc+i5yxrX7v+5XC5lZmaqa9eufmP0HYJF7rWO7Gsd2fefQfZtGrnXOrKvdWRf68i+LQcLygGcPHlScXFx3u34+HhVVlZecTwuLs5nvLW6++67vaHv6NGj+uyzzzRkyBDv+Pnz55WYmKi5c+eqpKRE586d09tvvx2i2YaXc+fOqX///iooKNDKlSu1du1abd261TtOzzVt27ZtunDhgka
NGuWzn77z9+qrryopKcm73dQ9T5LOnDmj9u3bKyoqSlLr7MHGdevcubMGDhwoSaqurlZhYaGGDRvmd15cXJyefvppbdiwQXfccYff11Jbg8a1O3XqlPr166dFixapuLhY5eXlWrdunc859NxljWv3t6NHj+rbb79VVlZWwPPoOwSL3Gsd2dc6sm/zkX2DQ+61juxrHdnXOrJvy8GCcgAej0c2m827bYzx2W5qvLU7dOiQsrOzNW/ePJ9Pjm666Sa999576t69u6KiopSdna3NmzeHbqJhpHfv3nK5XLr55ptlt9v16KOP+tSGnmva2rVrNXnyZL/99F3TgumvQPvowcsqKys1adIkZWRkqG/fvn7jBQUF6tOnj2w2m5588klt2bIlBLMMLwkJCSooKFB8fLzatm2riRMn+v1d0nNXV1RUpMcff1wxMTEBx+k7BIvc23xk32tH9m0+sq815N7mI/teO7Jv85F9ww8LygF06tRJVVVV3u2qqiqfx+0bj586dSrg4/it0a5du/TEE0/oueee05gxY3zGKioqfD6FM8Z4P31r7crLy7V9+3bvduPa0HNXV19fr507d+qhhx7yG6PvmtbUPU+S7Ha73G63Ll26dMVjWqNffvlFmZmZGjNmjGbOnOk37na7tXLlSu+2MUaRkZE3cIbh6cCBA/riiy+824H+Lum5q9u0aZNSU1MDjtF3uBbk3uYh+1pD9m0esq915N7mIftaQ/ZtPrJv+GFBOYABAwZo+/btqq6uVm1trTZu3KgHH3zQO965c2e1adNGu3btkiStX7/eZ7y1+v333zVz5kwtWbJEaWlpfuOxsbF6/fXXdezYMRljVFhYqOHDh4dgpuHH7XbL5XKprq5ONTU1Kikp8akNPXd1Bw4cUNeuXdWuXTu/Mfquaffff7+OHDmiX3/9VZcuXVJZWZlff0VHRyspKUmffvqpJOmTTz5p9T1YU1OjKVOm6Nlnn1V2dnbAY9q1a6f333/f+5/rV69eTf/pcshbtGiRzp49q4aGBhUVFfnVhZ67surqal24cEEJCQkBx+k7XAtyr3VkX+vIvs1D9rWO3Gsd2dc6sm/zkH3DEwvKAXTs2FGzZ89WVlaWRo8eLYfDoV69emnq1Knas2ePJGnJkiV67bXXlJKSor/++uuK73FpTT744APV1dVp8eLFSk9PV3p6utasWeOtm91uV15enmbMmKGUlBQZYwJ+Tas1Sk5O1pAhQzR69GhlZGQoIyNDvXv3pueCdOzYMXXq1MlnH30XvDZt2mjx4sV65plnlJqaqm7duiklJUWSNH/+fG3atEmSlJubq+LiYqWmpqq8vFyzZs0K4axDb926dTp16pRWrFjhvee9+eabkv5Xt8jISC1dulQLFizQqFGjtG/fPs2dOzfEMw+9nj17atq0aRo/frzS0tKUmJgoh8MhiZ4LxvHjx/3ueRJ9B2vIvdaRfa0j+zYP2dc6cq91ZF/ryL7NQ/YNTzZjjAn1JAAAAAAAAAAA4Y8nlAEAAAAAAAAAQWFBGQAAAAAAAAAQFBaUAQAAAAAAAABBYUEZAAAAAAAAABAUFpQBAAAAAAAAAEGJCvUEAAD+7rnnHvXo0UMREb6f+xUUFKhLly7/+O/avn277Hb7P/pzAQAAgGCQfQGgZWFBGQDC1KpVqwi6AAAAaBXIvgDQcrCgDAAtzI4dO7RkyRLdeeedOnz4sGJjY7V48WJ1795dbrdbr7zyivbv3y+bzabBgwdrzpw5ioqK0u7du5Wfn6/a2lpFR0dr3rx56t+/vyRp2bJl2r17t/78809NmTJFEyZMCPFVAgAAAGRfAAhHLCgDQJiaNGmSz9f+unTpooKCAknS3r179cILLygpKUlr1qzR3Llz9fHHHys/P18dOnRQaWmpGhoaNGPGDH344YeaPHmyZs6cqfz8fA0dOlR79+5VTk6O1q9fL0lKSEhQbm6ufvrpJ40bN06PPfaYoqOjQ3LdAAAAaH3IvgD
QcrCgDABh6mpf++vZs6eSkpIkSRkZGcrLy9OZM2f0zTffaM2aNbLZbIqJiVFmZqZWrVqlgQMHKiIiQkOHDpUk3XfffSotLfX+PIfDIUlKTExUfX29ampqdOutt17fCwQAAAD+i+wLAC1HRNOHAADCTWRkZMB9Ho9HNpvNu8/j8ejixYuKjIz02S9JBw8e1MWLFyVJUVGXP1/8+xhjzPWaOgAAAHBNyL4AEF5YUAaAFmj//v3av3+/JKmoqEi9e/fWLbfcokGDBmn16tUyxqi+vl7FxcUaMGCAunXrJpvNpq1bt0qS9u3bp0mTJsnj8YTyMgAAAIAmkX0BILzwygsACFON3yMnSXPmzFFsbKxuv/12LV26VCdOnJDdbpfL5ZIkvfzyy8rPz5fT6VRDQ4MGDx6sp556SjExMVq2bJkWLVokl8ul6OhoLVu2TDExMaG4NAAAAMAH2RcAWg6b4bsdANCi7NixQwsXLlRZWVmopwIAAABcV2RfAAg/vPICAAAAAAAAABAUnlAGAAAAAAAAAASFJ5QBAAAAAAAAAEFhQRkAAAAAAAAAEBQWlAEAAAAAAAAAQWFBGQAAAAAAAAAQFBaUAQAAAAAAAABBYUEZAAAAAAAAABCU/wBXKejImj+wXwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_pooling_adapted.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Hopfield-based Lookup

\n", + "\n", + "In contrast to the first Hopfield setting, in which the state patterns as well as the stored patterns are directly dependent on the input, HopfieldLayer employs a trainable but fixed stored pattern matrix, which in turn acts as a learnable lookup table." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "latch_samples_unique = [_[r'data'] for _ in data_loader_train]\n", + "latch_samples_unique = torch.cat(latch_samples_unique).view(-1, latch_samples_unique[0].shape[2]).unique(dim=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_lookup = HopfieldLayer(\n", + " input_size=latch_sequence_set.num_characters,\n", + " quantity=len(latch_samples_unique))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_lookup.output_size * latch_sequence_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield_lookup, Flatten(start_dim=1), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate Hopfield-based Lookup

" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=18)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_lookup.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Adapt Hopfield-based Lookup

\n", + "

We can now again explore the functionality of our Hopfield-based lookup layer HopfieldLayer.

\n", + "\n", + "This lookup setting is especially pronounced if the stored patterns are initialized with a subset of the training set (optionally with the corresponding training targets provided as pattern projection inputs).\n", + "\n", + "Again, additional arguments are set to increase the training as well as the validation performance of the Hopfield-based lookup.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ArgumentValue (used in this demo)Description
lookup_weights_as_separatedTrueSeparate lookup weights from lookup target weights (e.g. to set lookup target weights separately).
lookup_targets_as_trainableFalseEmploy trainable lookup target weights (used as pattern projection input).
" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "set_seed()\n", + "hopfield_lookup = HopfieldLayer(\n", + " input_size=latch_sequence_set.num_characters,\n", + " quantity=len(latch_samples_unique),\n", + " lookup_weights_as_separated=True,\n", + " lookup_targets_as_trainable=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialize the trainable but fixed stored patterns with all unique samples from the training set. In this way, the Hopfield-based lookup already starts with meaningful stored patterns (instead of random noise). This may enhance the performance of the network, especially at the beginning of the training." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " hopfield_lookup.lookup_weights[:] = latch_samples_unique.unsqueeze(dim=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "output_projection = Linear(in_features=hopfield_lookup.output_size * latch_sequence_set.num_instances, out_features=1)\n", + "network = Sequential(hopfield_lookup, Flatten(start_dim=1), output_projection, Flatten(start_dim=0)).to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "losses, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=18)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, accuracy=accuracies, log_file=f'{log_dir}/hopfield_lookup_adapted.pdf')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/mhnfs/hopfield/examples/latch_sequence/modules/__init__.py b/src/mhnfs/hopfield/examples/latch_sequence/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c11713739acb1aca3b32cca53acf0c3faf1fa9f6 --- /dev/null +++ b/src/mhnfs/hopfield/examples/latch_sequence/modules/__init__.py @@ -0,0 +1,898 @@ +import torch +import torch.nn as nn + +from math import sqrt +from torch import Tensor +from torch.nn import Module, Parameter +from typing import Optional, Tuple, Union + +from .activation import HopfieldCore + + +class Hopfield(Module): + """ + Module with underlying Hopfield association. 
+ """ + + def __init__(self, + input_size: Optional[int] = None, + hidden_size: Optional[int] = None, + output_size: Optional[int] = None, + pattern_size: Optional[int] = None, + num_heads: int = 1, + scaling: Optional[Union[float, Tensor]] = None, + update_steps_max: Optional[Union[int, Tensor]] = 0, + update_steps_eps: Union[float, Tensor] = 1e-4, + + normalize_stored_pattern: bool = True, + normalize_stored_pattern_affine: bool = True, + normalize_state_pattern: bool = True, + normalize_state_pattern_affine: bool = True, + normalize_pattern_projection: bool = True, + normalize_pattern_projection_affine: bool = True, + normalize_hopfield_space: bool = False, + normalize_hopfield_space_affine: bool = False, + stored_pattern_as_static: bool = False, + state_pattern_as_static: bool = False, + pattern_projection_as_static: bool = False, + pattern_projection_as_connected: bool = False, + stored_pattern_size: Optional[int] = None, + pattern_projection_size: Optional[int] = None, + + batch_first: bool = True, + association_activation: Optional[str] = None, + dropout: float = 0.0, + input_bias: bool = True, + concat_bias_pattern: bool = False, + add_zero_association: bool = False, + disable_out_projection: bool = False + ): + """ + Initialise new instance of a Hopfield module. 
+ + :param input_size: depth of the input (state pattern) + :param hidden_size: depth of the association space + :param output_size: depth of the output projection + :param pattern_size: depth of patterns to be selected + :param num_heads: amount of parallel association heads + :param scaling: scaling of association heads, often represented as beta (one entry per head) + :param update_steps_max: maximum count of association update steps (None equals to infinity) + :param update_steps_eps: minimum difference threshold between two consecutive association update steps + :param normalize_stored_pattern: apply normalization on stored patterns + :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns + :param normalize_state_pattern: apply normalization on state patterns + :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns + :param normalize_pattern_projection: apply normalization on the pattern projection + :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection + :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space + :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space + :param stored_pattern_as_static: interpret specified stored patterns as being static + :param state_pattern_as_static: interpret specified state patterns as being static + :param pattern_projection_as_static: interpret specified pattern projections as being static + :param pattern_projection_as_connected: connect pattern projection with stored pattern + :param stored_pattern_size: depth of input (stored pattern) + :param pattern_projection_size: depth of input (pattern projection) + :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size + :param association_activation: additional activation to be applied on the 
result of the Hopfield association + :param dropout: dropout probability applied on the association matrix + :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection) + :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection + :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection + :param disable_out_projection: disable output projection + """ + super(Hopfield, self).__init__() + assert type(batch_first) == bool, f'"batch_first" needs to be a boolean, not {type(batch_first)}.' + assert (association_activation is None) or (type(association_activation) == str) + + # Initialise Hopfield association module. + self.association_core = HopfieldCore( + embed_dim=input_size, num_heads=num_heads, dropout=dropout, bias=input_bias, + add_bias_kv=concat_bias_pattern, add_zero_attn=add_zero_association, kdim=stored_pattern_size, + vdim=pattern_projection_size, head_dim=hidden_size, pattern_dim=pattern_size, out_dim=output_size, + disable_out_projection=disable_out_projection, key_as_static=stored_pattern_as_static, + query_as_static=state_pattern_as_static, value_as_static=pattern_projection_as_static, + value_as_connected=pattern_projection_as_connected, normalize_pattern=normalize_hopfield_space, + normalize_pattern_affine=normalize_hopfield_space_affine) + self.association_activation = None + if association_activation is not None: + self.association_activation = getattr(torch, association_activation, None) + + # Initialise stored pattern normalization. + self.norm_stored_pattern = None + if normalize_stored_pattern_affine: + assert normalize_stored_pattern, "affine normalization without normalization has no effect." 
+ if normalize_stored_pattern: + normalized_shape = input_size if stored_pattern_size is None else stored_pattern_size + assert normalized_shape is not None, "stored pattern size required for setting up normalisation" + self.norm_stored_pattern = nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine) + + # Initialise state pattern normalization. + self.norm_state_pattern = None + if normalize_state_pattern_affine: + assert normalize_state_pattern, "affine normalization without normalization has no effect." + if normalize_state_pattern: + assert input_size is not None, "input size required for setting up normalisation" + self.norm_state_pattern = nn.LayerNorm( + normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine) + + # Initialise pattern projection normalization. + self.norm_pattern_projection = None + if normalize_pattern_projection_affine: + assert normalize_pattern_projection, "affine normalization without normalization has no effect." + if normalize_pattern_projection: + normalized_shape = input_size if pattern_projection_size is None else pattern_projection_size + assert normalized_shape is not None, "pattern projection size required for setting up normalisation" + self.norm_pattern_projection = nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=normalize_pattern_projection_affine) + + # Initialise remaining auxiliary properties. + if self.association_core.static_execution: + self.__scaling = 1.0 if scaling is None else scaling + else: + assert self.association_core.head_dim > 0, f'invalid hidden dimension encountered.' + self.__scaling = (1.0 / sqrt(self.association_core.head_dim)) if scaling is None else scaling + self.__batch_first = batch_first + self.__update_steps_max = update_steps_max + self.__update_steps_eps = update_steps_eps + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Reset Hopfield association. 
+ + :return: None + """ + for module in (self.association_core, self.norm_stored_pattern, + self.norm_state_pattern, self.norm_pattern_projection): + if hasattr(module, r'reset_parameters'): + module.reset_parameters() + + def _maybe_transpose(self, *args: Tuple[Tensor, ...]) -> Union[Tensor, Tuple[Tensor, ...]]: + """ + Eventually transpose specified data. + + :param args: tensors to eventually transpose (dependent on the state of "batch_first") + :return: eventually transposed tensors + """ + transposed_result = tuple(_.transpose(0, 1) for _ in args) if self.__batch_first else args + return transposed_result[0] if len(transposed_result) == 1 else transposed_result + + def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + return_raw_associations: bool = False, return_projected_patterns: bool = False, + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]: + """ + Apply Hopfield association module on specified data. + + :param data: data to be processed by Hopfield core module + :param return_raw_associations: return raw association (softmax) values, unmodified + :param return_projected_patterns: return pattern projection values, unmodified + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-processed input data + """ + assert (type(data) == Tensor) or ((type(data) == tuple) and (len(data) == 3)), \ + r'either one tensor to be used as "stored pattern", "state pattern" and' \ + r' "pattern_projection" must be provided, or three separate ones.' + if type(data) == Tensor: + stored_pattern, state_pattern, pattern_projection = data, data, data + else: + stored_pattern, state_pattern, pattern_projection = data + + # Optionally transpose data. 
+ stored_pattern, state_pattern, pattern_projection = self._maybe_transpose( + stored_pattern, state_pattern, pattern_projection) + + # Optionally apply stored pattern normalization. + if self.norm_stored_pattern is not None: + stored_pattern = self.norm_stored_pattern(input=stored_pattern.reshape( + shape=(-1, stored_pattern.shape[2]))).reshape(shape=stored_pattern.shape) + + # Optionally apply state pattern normalization. + if self.norm_state_pattern is not None: + state_pattern = self.norm_state_pattern(input=state_pattern.reshape( + shape=(-1, state_pattern.shape[2]))).reshape(shape=state_pattern.shape) + + # Optionally apply pattern projection normalization. + if self.norm_pattern_projection is not None: + pattern_projection = self.norm_pattern_projection(input=pattern_projection.reshape( + shape=(-1, pattern_projection.shape[2]))).reshape(shape=pattern_projection.shape) + + # Apply Hopfield association and optional activation function. + return self.association_core( + query=state_pattern, key=stored_pattern, value=pattern_projection, + key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask, + scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps, + return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns) + + def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Apply Hopfield association on specified data. 
+ + :param input: data to be processed by Hopfield association module + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-processed input data + """ + association_output = self._maybe_transpose(self._associate( + data=input, return_raw_associations=False, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[0]) + if self.association_activation is not None: + association_output = self.association_activation(association_output) + return association_output + + def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield association matrix gathered by passing through the specified data. + + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: association matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self._associate( + data=input, return_raw_associations=True, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[2] + + def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield projected pattern matrix gathered by passing through the specified data. 
+ + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: pattern projection matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self._associate( + data=input, return_projected_patterns=True, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[3] + + @property + def batch_first(self) -> bool: + return self.__batch_first + + @property + def scaling(self) -> Union[float, Tensor]: + return self.__scaling.clone() if type(self.__scaling) == Tensor else self.__scaling + + @property + def stored_pattern_dim(self) -> Optional[int]: + return self.association_core.kdim + + @property + def state_pattern_dim(self) -> Optional[int]: + return self.association_core.embed_dim + + @property + def pattern_projection_dim(self) -> Optional[int]: + return self.association_core.vdim + + @property + def input_size(self) -> Optional[int]: + return self.state_pattern_dim + + @property + def hidden_size(self) -> Optional[int]: + return self.association_core.head_dim + + @property + def output_size(self) -> Optional[int]: + return self.association_core.out_dim + + @property + def pattern_size(self) -> Optional[int]: + return self.association_core.pattern_dim + + @property + def update_steps_max(self) -> Optional[Union[int, Tensor]]: + return self.__update_steps_max.clone() if type(self.__update_steps_max) == Tensor else self.__update_steps_max + + @property + def update_steps_eps(self) -> Optional[Union[float, Tensor]]: + return self.__update_steps_eps.clone() if type(self.__update_steps_eps) == Tensor else self.__update_steps_eps + + @property + def stored_pattern_as_static(self) -> bool: + return self.association_core.key_as_static + + @property + def state_pattern_as_static(self) -> bool: + return self.association_core.query_as_static + + 
@property
def normalize_hopfield_space(self) -> bool:
    """
    Tell whether patterns are normalized in the Hopfield space.

    :return: True if pattern normalization is enabled in the association core
    """
    # Read from the association core like every other accessor of this class; a "hopfield"
    # attribute is only set by the pooling/lookup wrappers, so going through it fails here.
    # NOTE(review): relies on the core exposing "normalize_pattern" (as HopfieldCore does) - confirm.
    return self.association_core.normalize_pattern

@property
def normalize_hopfield_space_affine(self) -> bool:
    """
    Tell whether the normalization in the Hopfield space uses affine parameters.

    :return: True if affine pattern normalization is enabled in the association core
    """
    # Same reasoning as for "normalize_hopfield_space": the flag is held by the association core.
    return self.association_core.normalize_pattern_affine
+ """ + + def __init__(self, + input_size: int, + hidden_size: Optional[int] = None, + output_size: Optional[int] = None, + pattern_size: Optional[int] = None, + num_heads: int = 1, + scaling: Optional[Union[float, Tensor]] = None, + update_steps_max: Optional[Union[int, Tensor]] = 0, + update_steps_eps: Union[float, Tensor] = 1e-4, + + normalize_stored_pattern: bool = True, + normalize_stored_pattern_affine: bool = True, + normalize_state_pattern: bool = True, + normalize_state_pattern_affine: bool = True, + normalize_pattern_projection: bool = True, + normalize_pattern_projection_affine: bool = True, + normalize_hopfield_space: bool = False, + normalize_hopfield_space_affine: bool = False, + stored_pattern_as_static: bool = False, + state_pattern_as_static: bool = False, + pattern_projection_as_static: bool = False, + pattern_projection_as_connected: bool = False, + stored_pattern_size: Optional[int] = None, + pattern_projection_size: Optional[int] = None, + + batch_first: bool = True, + association_activation: Optional[str] = None, + dropout: float = 0.0, + input_bias: bool = True, + concat_bias_pattern: bool = False, + add_zero_association: bool = False, + disable_out_projection: bool = False, + quantity: int = 1, + trainable: bool = True + ): + """ + Initialise a new instance of a Hopfield-based pooling layer. 
+ + :param input_size: depth of the input (state pattern) + :param hidden_size: depth of the association space + :param output_size: depth of the output projection + :param pattern_size: depth of patterns to be selected + :param num_heads: amount of parallel association heads + :param scaling: scaling of association heads, often represented as beta (one entry per head) + :param update_steps_max: maximum count of association update steps (None equals to infinity) + :param update_steps_eps: minimum difference threshold between two consecutive association update steps + :param normalize_stored_pattern: apply normalization on stored patterns + :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns + :param normalize_state_pattern: apply normalization on state patterns + :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns + :param normalize_pattern_projection: apply normalization on the pattern projection + :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection + :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space + :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space + :param stored_pattern_as_static: interpret specified stored patterns as being static + :param state_pattern_as_static: interpret specified state patterns as being static + :param pattern_projection_as_static: interpret specified pattern projections as being static + :param pattern_projection_as_connected: connect pattern projection with stored pattern + :param stored_pattern_size: depth of input (stored pattern) + :param pattern_projection_size: depth of input (pattern projection) + :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size + :param association_activation: additional activation to be applied on the 
result of the Hopfield association + :param dropout: dropout probability applied on the association matrix + :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection) + :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection + :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection + :param disable_out_projection: disable output projection + :param quantity: amount of state patterns + :param trainable: state pattern used for pooling is trainable + """ + super(HopfieldPooling, self).__init__() + self.hopfield = Hopfield( + input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size, + num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps, + normalize_stored_pattern=normalize_stored_pattern, + normalize_stored_pattern_affine=normalize_stored_pattern_affine, + normalize_state_pattern=normalize_state_pattern, + normalize_state_pattern_affine=normalize_state_pattern_affine, + normalize_pattern_projection=normalize_pattern_projection, + normalize_pattern_projection_affine=normalize_pattern_projection_affine, + normalize_hopfield_space=normalize_hopfield_space, + normalize_hopfield_space_affine=normalize_hopfield_space_affine, + stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static, + pattern_projection_as_static=pattern_projection_as_static, + pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size, + pattern_projection_size=pattern_projection_size, batch_first=batch_first, + association_activation=association_activation, dropout=dropout, input_bias=input_bias, + concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association, + disable_out_projection=disable_out_projection) + self._quantity = quantity + pooling_weight_size = 
self.hopfield.hidden_size if state_pattern_as_static else self.hopfield.input_size + self.pooling_weights = nn.Parameter(torch.empty(size=(*( + (1, quantity) if batch_first else (quantity, 1) + ), input_size if pooling_weight_size is None else pooling_weight_size)), requires_grad=trainable) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Reset pooling weights and underlying Hopfield association. + + :return: None + """ + if hasattr(self.hopfield, r'reset_parameters'): + self.hopfield.reset_parameters() + + # Explicitly initialise pooling weights. + nn.init.normal_(self.pooling_weights, mean=0.0, std=0.02) + + def _prepare_input(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]: + """ + Prepare input for Hopfield association. + + :param input: data to be prepared + :return: stored pattern, expanded state pattern as well as pattern projection + """ + assert (type(input) == Tensor) or ((type(input) == tuple) and (len(input) == 2)), \ + r'either one tensor to be used as "stored pattern" and' \ + r' "pattern_projection" must be provided, or two separate ones.' + if type(input) == Tensor: + stored_pattern, pattern_projection = input, input + else: + stored_pattern, pattern_projection = input + + batch_size = stored_pattern.shape[0 if self.batch_first else 1] + return stored_pattern, self.pooling_weights.expand(size=(*( + (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size) + ), self.pooling_weights.shape[2])), pattern_projection + + def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor]], stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Compute Hopfield-based pooling on specified data. 
def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                           stored_pattern_padding_mask: Optional[Tensor] = None,
                           association_mask: Optional[Tensor] = None) -> Tensor:
    """
    Pass the specified data through the pooling association and fetch the resulting association matrix.

    :param input: data to be passed through the Hopfield association
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: association matrix as computed by the Hopfield core module
    """
    # Input preparation as well as the association itself run without gradient tracking.
    with torch.no_grad():
        prepared_input = self._prepare_input(input=input)
        association_matrix = self.hopfield.get_association_matrix(
            input=prepared_input,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
    return association_matrix
+ + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: pattern projection matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self.hopfield.get_projected_pattern_matrix( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask) + + @property + def batch_first(self) -> bool: + return self.hopfield.batch_first + + @property + def scaling(self) -> Union[float, Tensor]: + return self.hopfield.scaling + + @property + def stored_pattern_dim(self) -> Optional[int]: + return self.hopfield.stored_pattern_dim + + @property + def state_pattern_dim(self) -> Optional[int]: + return self.hopfield.state_pattern_dim + + @property + def pattern_projection_dim(self) -> Optional[int]: + return self.hopfield.pattern_projection_dim + + @property + def input_size(self) -> Optional[int]: + return self.hopfield.input_size + + @property + def hidden_size(self) -> int: + return self.hopfield.hidden_size + + @property + def output_size(self) -> Optional[int]: + return self.hopfield.output_size + + @property + def pattern_size(self) -> Optional[int]: + return self.hopfield.pattern_size + + @property + def quantity(self) -> int: + return self._quantity + + @property + def update_steps_max(self) -> Optional[Union[int, Tensor]]: + return self.hopfield.update_steps_max + + @property + def update_steps_eps(self) -> Optional[Union[float, Tensor]]: + return self.hopfield.update_steps_eps + + @property + def stored_pattern_as_static(self) -> bool: + return self.hopfield.stored_pattern_as_static + + @property + def state_pattern_as_static(self) -> bool: + return self.hopfield.state_pattern_as_static + + @property + def pattern_projection_as_static(self) -> bool: + return self.hopfield.pattern_projection_as_static + + 
@property + def normalize_stored_pattern(self) -> bool: + return self.hopfield.normalize_stored_pattern + + @property + def normalize_stored_pattern_affine(self) -> bool: + return self.hopfield.normalize_stored_pattern_affine + + @property + def normalize_state_pattern(self) -> bool: + return self.hopfield.normalize_state_pattern + + @property + def normalize_state_pattern_affine(self) -> bool: + return self.hopfield.normalize_state_pattern_affine + + @property + def normalize_pattern_projection(self) -> bool: + return self.hopfield.normalize_pattern_projection + + @property + def normalize_pattern_projection_affine(self) -> bool: + return self.hopfield.normalize_pattern_projection_affine + + +class HopfieldLayer(Module): + """ + Wrapper class encapsulating a trainable but fixed stored pattern, pattern projection and "Hopfield" in + one combined module to be used as a Hopfield-based pooling layer. + """ + + def __init__(self, + input_size: int, + hidden_size: Optional[int] = None, + output_size: Optional[int] = None, + pattern_size: Optional[int] = None, + num_heads: int = 1, + scaling: Optional[Union[float, Tensor]] = None, + update_steps_max: Optional[Union[int, Tensor]] = 0, + update_steps_eps: Union[float, Tensor] = 1e-4, + lookup_weights_as_separated: bool = False, + lookup_targets_as_trainable: bool = True, + + normalize_stored_pattern: bool = True, + normalize_stored_pattern_affine: bool = True, + normalize_state_pattern: bool = True, + normalize_state_pattern_affine: bool = True, + normalize_pattern_projection: bool = True, + normalize_pattern_projection_affine: bool = True, + normalize_hopfield_space: bool = False, + normalize_hopfield_space_affine: bool = False, + stored_pattern_as_static: bool = False, + state_pattern_as_static: bool = False, + pattern_projection_as_static: bool = False, + pattern_projection_as_connected: bool = False, + stored_pattern_size: Optional[int] = None, + pattern_projection_size: Optional[int] = None, + + batch_first: bool = 
True, + association_activation: Optional[str] = None, + dropout: float = 0.0, + input_bias: bool = True, + concat_bias_pattern: bool = False, + add_zero_association: bool = False, + disable_out_projection: bool = False, + quantity: int = 1, + trainable: bool = True + ): + """ + Initialise a new instance of a Hopfield-based lookup layer. + + :param input_size: depth of the input (state pattern) + :param hidden_size: depth of the association space + :param output_size: depth of the output projection + :param pattern_size: depth of patterns to be selected + :param num_heads: amount of parallel association heads + :param scaling: scaling of association heads, often represented as beta (one entry per head) + :param update_steps_max: maximum count of association update steps (None equals to infinity) + :param update_steps_eps: minimum difference threshold between two consecutive association update steps + :param lookup_weights_as_separated: separate lookup weights from lookup target weights + :param lookup_targets_as_trainable: employ trainable lookup target weights (used as pattern projection input) + :param normalize_stored_pattern: apply normalization on stored patterns + :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns + :param normalize_state_pattern: apply normalization on state patterns + :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns + :param normalize_pattern_projection: apply normalization on the pattern projection + :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection + :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space + :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space + :param stored_pattern_as_static: interpret specified stored patterns as being static + :param state_pattern_as_static: interpret specified 
state patterns as being static + :param pattern_projection_as_static: interpret specified pattern projections as being static + :param pattern_projection_as_connected: connect pattern projection with stored pattern + :param stored_pattern_size: depth of input (stored pattern) + :param pattern_projection_size: depth of input (pattern projection) + :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size + :param association_activation: additional activation to be applied on the result of the Hopfield association + :param dropout: dropout probability applied on the association matrix + :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection) + :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection + :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection + :param disable_out_projection: disable output projection + :param quantity: amount of stored patterns + :param trainable: stored pattern used for lookup is trainable + """ + super(HopfieldLayer, self).__init__() + self.hopfield = Hopfield( + input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size, + num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps, + normalize_stored_pattern=normalize_stored_pattern, + normalize_stored_pattern_affine=normalize_stored_pattern_affine, + normalize_state_pattern=normalize_state_pattern, + normalize_state_pattern_affine=normalize_state_pattern_affine, + normalize_pattern_projection=normalize_pattern_projection, + normalize_pattern_projection_affine=normalize_pattern_projection_affine, + normalize_hopfield_space=normalize_hopfield_space, + normalize_hopfield_space_affine=normalize_hopfield_space_affine, + stored_pattern_as_static=stored_pattern_as_static, 
state_pattern_as_static=state_pattern_as_static, + pattern_projection_as_static=pattern_projection_as_static, + pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size, + pattern_projection_size=pattern_projection_size, batch_first=batch_first, + association_activation=association_activation, dropout=dropout, input_bias=input_bias, + concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association, + disable_out_projection=disable_out_projection) + self._quantity = quantity + lookup_weight_size = self.hopfield.hidden_size if stored_pattern_as_static else self.hopfield.stored_pattern_dim + self.lookup_weights = nn.Parameter(torch.empty(size=(*( + (1, quantity) if batch_first else (quantity, 1) + ), input_size if lookup_weight_size is None else lookup_weight_size)), requires_grad=trainable) + + if lookup_weights_as_separated: + target_weight_size = self.lookup_weights.shape[ + 2] if pattern_projection_size is None else pattern_projection_size + self.target_weights = nn.Parameter(torch.empty(size=(*( + (1, quantity) if batch_first else (quantity, 1) + ), target_weight_size)), requires_grad=lookup_targets_as_trainable) + else: + self.register_parameter(name=r'target_weights', param=None) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Reset lookup and lookup target weights, including underlying Hopfield association. + + :return: None + """ + if hasattr(self.hopfield, r'reset_parameters'): + self.hopfield.reset_parameters() + + # Explicitly initialise lookup and target weights. + nn.init.normal_(self.lookup_weights, mean=0.0, std=0.02) + if self.target_weights is not None: + nn.init.normal_(self.target_weights, mean=0.0, std=0.02) + + def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Prepare input for Hopfield association. 
+ + :param input: data to be prepared + :return: stored pattern, expanded state pattern as well as pattern projection + """ + batch_size = input.shape[0 if self.batch_first else 1] + stored_pattern = self.lookup_weights.expand(size=(*( + (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size) + ), self.lookup_weights.shape[2])) + if self.target_weights is None: + pattern_projection = stored_pattern + else: + pattern_projection = self.target_weights.expand(size=(*( + (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size) + ), self.target_weights.shape[2])) + + return stored_pattern, input, pattern_projection + + def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Compute Hopfield-based lookup on specified data. + + :param input: data to used in lookup + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: result of Hopfield-based lookup on input data + """ + return self.hopfield( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask) + + def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield association matrix used for lookup gathered by passing through the specified data. 
def get_projected_pattern_matrix(self, input: Tensor,
                                 stored_pattern_padding_mask: Optional[Tensor] = None,
                                 association_mask: Optional[Tensor] = None) -> Tensor:
    """
    Fetch Hopfield projected pattern matrix gathered by passing through the specified data.

    The input is annotated as a plain tensor, in line with "forward" and "get_association_matrix"
    of this class: it is handed to "_prepare_input", which reads its shape directly.

    :param input: data to be passed through the Hopfield association
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: pattern projection matrix as computed by the Hopfield core module
    """
    # Input preparation as well as the association itself run without gradient tracking.
    with torch.no_grad():
        prepared_input = self._prepare_input(input=input)
        projected_patterns = self.hopfield.get_projected_pattern_matrix(
            input=prepared_input,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
    return projected_patterns
output_size(self) -> Optional[int]: + return self.hopfield.output_size + + @property + def pattern_size(self) -> Optional[int]: + return self.hopfield.pattern_size + + @property + def quantity(self) -> int: + return self._quantity + + @property + def update_steps_max(self) -> Optional[Union[int, Tensor]]: + return self.hopfield.update_steps_max + + @property + def update_steps_eps(self) -> Optional[Union[float, Tensor]]: + return self.hopfield.update_steps_eps + + @property + def stored_pattern_as_static(self) -> bool: + return self.hopfield.stored_pattern_as_static + + @property + def state_pattern_as_static(self) -> bool: + return self.hopfield.state_pattern_as_static + + @property + def pattern_projection_as_static(self) -> bool: + return self.hopfield.pattern_projection_as_static + + @property + def normalize_stored_pattern(self) -> bool: + return self.hopfield.normalize_stored_pattern + + @property + def normalize_stored_pattern_affine(self) -> bool: + return self.hopfield.normalize_stored_pattern_affine + + @property + def normalize_state_pattern(self) -> bool: + return self.hopfield.normalize_state_pattern + + @property + def normalize_state_pattern_affine(self) -> bool: + return self.hopfield.normalize_state_pattern_affine + + @property + def normalize_pattern_projection(self) -> bool: + return self.hopfield.normalize_pattern_projection + + @property + def normalize_pattern_projection_affine(self) -> bool: + return self.hopfield.normalize_pattern_projection_affine diff --git a/src/mhnfs/hopfield/examples/latch_sequence/modules/activation.py b/src/mhnfs/hopfield/examples/latch_sequence/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..6dcd6e56cddb954cf6b049687bdf5e7783aa2bc9 --- /dev/null +++ b/src/mhnfs/hopfield/examples/latch_sequence/modules/activation.py @@ -0,0 +1,337 @@ +import torch +import torch.nn as nn + +from torch import Tensor +from torch.nn import Linear, Module, Parameter +from typing import Optional + 
+from .functional import hopfield_core_forward + +try: + from torch.nn.modules.linear import _LinearWithBias +except ImportError: + _LinearWithBias = None + + +class HopfieldCore(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See references: "Hopfield Networks is All You Need" and + "Attention Is All You Need" (on which this implementation is partly based on). + + .. math:: + \text{HopfieldHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. 
+ + Examples:: + + >>> hopfield_attn = HopfieldCore(embed_dim, num_heads) + >>> attn_output, attn_output_weights, attn_matrix = hopfield_attn(query, key, value) + """ + __annotations__ = { + 'bias_k': torch._jit_internal.Optional[torch.Tensor], + 'bias_v': torch._jit_internal.Optional[torch.Tensor], + } + + def __init__(self, + embed_dim=None, # type: Optional[int] + num_heads=1, # type: int + dropout=0.0, # type: float + bias=True, # type: bool + add_bias_kv=False, # type: bool + add_zero_attn=False, # type: bool + kdim=None, # type: Optional[int] + vdim=None, # type: Optional[int] + + head_dim=None, # type: Optional[int] + pattern_dim=None, # type: Optional[int] + out_dim=None, # type: Optional[int] + disable_out_projection=False, # type: bool + key_as_static=False, # type: bool + query_as_static=False, # type: bool + value_as_static=False, # type: bool + value_as_connected=False, # type: bool + normalize_pattern=False, # type: bool + normalize_pattern_affine=False # type: bool + ): + super(HopfieldCore, self).__init__() + + assert (type(key_as_static) == bool) and (type(query_as_static) == bool) and (type(value_as_static) == bool) + self.key_as_static, self.query_as_static, self.value_as_static = key_as_static, query_as_static, value_as_static + num_non_static = 3 - (self.key_as_static + self.query_as_static + self.value_as_static) + assert 0 <= num_non_static < 4 + + self.value_as_connected = value_as_connected + self.normalize_pattern, self.normalize_pattern_affine = normalize_pattern, normalize_pattern_affine + self.disable_out_projection = disable_out_projection + + # In case of a static-only executions, check corresponding projections and normalizations. + self.static_execution = self._check_execution_mode() + if self.static_execution: + embed_dim, kdim, vdim = None, None, None + if embed_dim is None: + assert self.static_execution, r'static-only execution requires all projections to be deactivated.' 
+ + # Check and set all other properties, conditioned on . + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = all(( + self.kdim == embed_dim, self.vdim == embed_dim, pattern_dim is None, not self.value_as_connected)) + assert (not self.value_as_connected) or (self.kdim == self.vdim), r'key and value need to be of same dimension.' + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = None + self.pattern_dim = pattern_dim + self.virtual_hopfield_dim = None + self.virtual_pattern_dim = None + if not self.static_execution: + if head_dim is None: + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads." + else: + assert head_dim > 0, "dimension of the association space has to be positive." + self.head_dim = head_dim + if self.pattern_dim is None: + self.pattern_dim = self.head_dim + self.virtual_hopfield_dim = self.num_heads * self.head_dim + self.virtual_pattern_dim = self.num_heads * self.pattern_dim + + self.out_dim = embed_dim if out_dim is None else out_dim + assert disable_out_projection or (self.out_dim > 0), "output projection dimension has to be positive." + + if normalize_pattern_affine: + assert normalize_pattern, "affine pattern normalization without pattern normalization has no effect." 
+ self.p_norm_weight = Parameter(torch.Tensor(head_dim)) + self.p_norm_bias = Parameter(torch.Tensor(head_dim)) + else: + self.register_parameter('p_norm_weight', None) + self.register_parameter('p_norm_bias', None) + + if self._qkv_same_embed_dim is False: + if query_as_static: + self.register_parameter('q_proj_weight', None) + else: + self.q_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, embed_dim)) + if key_as_static: + self.register_parameter('k_proj_weight', None) + else: + self.k_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, self.kdim)) + if value_as_static: + self.register_parameter('v_proj_weight', None) + else: + self.v_proj_weight = Parameter(torch.Tensor( + self.virtual_pattern_dim, + self.virtual_hopfield_dim if (value_as_connected and not key_as_static) else self.vdim)) + self.register_parameter('in_proj_weight', None) + else: + if num_non_static > 0: + self.in_proj_weight = Parameter(torch.empty( + (not query_as_static) * self.virtual_hopfield_dim + + (not key_as_static) * self.virtual_hopfield_dim + + (not value_as_static) * self.virtual_pattern_dim, embed_dim)) + else: + self.register_parameter('in_proj_weight', None) + self.register_parameter('q_proj_weight', None) + self.register_parameter('k_proj_weight', None) + self.register_parameter('v_proj_weight', None) + + if bias and (num_non_static > 0): + self.in_proj_bias = Parameter(torch.empty( + (not query_as_static) * self.virtual_hopfield_dim + + (not key_as_static) * self.virtual_hopfield_dim + self.virtual_pattern_dim)) + else: + self.register_parameter('in_proj_bias', None) + if disable_out_projection: + self.register_parameter('out_proj', None) + else: + if bias and _LinearWithBias is not None: + self.out_proj = _LinearWithBias(self.virtual_pattern_dim, self.out_dim) + else: + self.out_proj = Linear(self.virtual_pattern_dim, self.out_dim, bias=bias) + + self.bias_k, self.bias_v = None, None + if add_bias_kv: + if not key_as_static: + self.bias_k = 
Parameter(torch.empty(1, 1, self.virtual_hopfield_dim)) + if not value_as_static: + self.bias_v = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim)) + assert not (self.bias_k is None and self.bias_v is None), r'cannot set key/value bias if both are static.' + + self.add_zero_attn = add_zero_attn + self.reset_parameters() + + def _check_execution_mode(self) -> bool: + return all(( + self.key_as_static, self.query_as_static, self.value_as_static, not self.value_as_connected, + not self.normalize_pattern, not self.normalize_pattern_affine, self.disable_out_projection + )) + + def reset_parameters(self): + if self.p_norm_weight is not None: + nn.init.ones_(self.p_norm_weight) + nn.init.zeros_(self.p_norm_bias) + + if self._qkv_same_embed_dim and (self.in_proj_weight is not None): + nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02) + else: + if self.q_proj_weight is not None: + nn.init.normal_(self.q_proj_weight, mean=0.0, std=0.02) + if self.k_proj_weight is not None: + nn.init.normal_(self.k_proj_weight, mean=0.0, std=0.02) + if self.v_proj_weight is not None: + nn.init.normal_(self.v_proj_weight, mean=0.0, std=0.02) + + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.0) + if not self.disable_out_projection: + nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.02) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.normal_(self.bias_k, mean=0.0, std=0.02) + if self.bias_v is not None: + nn.init.normal_(self.bias_v, mean=0.0, std=0.02) + + def __setstate__(self, state): + super(HopfieldCore, self).__setstate__(state) + + def forward(self, + query, # type: Tensor + key, # type: Tensor + value, # type: Tensor + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + + scaling=None, # type: Optional[Tensor] + update_steps_max=0, # type: Optional[int] + update_steps_eps=1e-4, # type: float + 
def forward(self,
            query,  # type: Tensor
            key,  # type: Tensor
            value,  # type: Tensor
            key_padding_mask=None,  # type: Optional[Tensor]
            need_weights=True,  # type: bool
            attn_mask=None,  # type: Optional[Tensor]

            scaling=None,  # type: Optional[Tensor]
            update_steps_max=0,  # type: Optional[int]
            update_steps_eps=1e-4,  # type: float
            return_raw_associations=False,  # type: bool
            return_pattern_projections=False  # type: bool
            ):
    # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
    r"""
    Validate the feature dimensions of the inputs and delegate to ``hopfield_core_forward``.

    Args:
        query, key, value: state pattern, stored pattern and pattern projection inputs. Only their last
            dimension (``shape[2]``) is checked here; everything else is handed over unchanged.
        key_padding_mask: mask for padded key elements, forwarded unchanged.
        need_weights: forwarded unchanged; requests the (head-averaged) association weights.
        attn_mask: mask applied on the association matrix, forwarded unchanged.

        scaling: scaling of the association heads (often referred to as beta), forwarded unchanged.
        update_steps_max: maximum count of association update steps, forwarded unchanged.
        update_steps_eps: minimum difference threshold between two update steps, forwarded unchanged.
        return_raw_associations: forwarded unchanged; requests the raw association values.
        return_pattern_projections: forwarded as ``return_projected_patterns``.

    Returns:
        Whatever ``hopfield_core_forward`` returns for the assembled arguments.

    Raises:
        AssertionError: if the last dimension of query, key or value does not fit the module configuration.
    """
    if not (self.query_as_static and self.key_as_static):
        # A projected input has to match its projection's input size, a static one the head size.
        assert self.query_as_static or (query.shape[2] == self.embed_dim), \
            f'query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}.'
        assert (not self.query_as_static) or (query.shape[2] == self.head_dim), \
            f'query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}'

        assert self.key_as_static or (key.shape[2] == self.kdim), \
            f'key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}.'
        assert (not self.key_as_static) or (key.shape[2] == self.head_dim), \
            f'key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}'
        head_dim = self.head_dim
        embed_dim_to_check = self.head_dim if self.query_as_static else self.embed_dim
    else:
        # Without any projection, the head size is dictated by the inputs themselves.
        assert query.shape[2] == key.shape[2], \
            f'query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal'
        head_dim = embed_dim_to_check = query.shape[2]

    assert self.value_as_static or (value.shape[2] == self.vdim), \
        f'value shape[2] of {value.shape[2]} invalid, needs to be {self.vdim}.'
    assert (not self.value_as_static) or (value.shape[2] == self.pattern_dim) or self.disable_out_projection, \
        f'value shape[2] of {value.shape[2]} invalid, needs to be {self.pattern_dim}'

    if self.disable_out_projection:
        out_weights, out_bias = None, None
    else:
        out_weights, out_bias = self.out_proj.weight, self.out_proj.bias

    core_arguments = dict(
        query=query, key=key, value=value, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
        in_proj_weight=self.in_proj_weight, in_proj_bias=self.in_proj_bias, bias_k=self.bias_k,
        bias_v=self.bias_v, add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
        out_proj_weight=out_weights, out_proj_bias=out_bias, training=self.training,
        key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,

        key_as_static=self.key_as_static, query_as_static=self.query_as_static,
        value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
        normalize_pattern=self.normalize_pattern,
        p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
        head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
        update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
        return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)

    # Separate projection weights are only handed over if the inputs differ in their feature size.
    if not self._qkv_same_embed_dim:
        core_arguments.update(
            use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight)

    return hopfield_core_forward(**core_arguments)
def hopfield_core_forward(query,  # type: Tensor
                          key,  # type: Tensor
                          value,  # type: Tensor
                          embed_dim_to_check,  # type: int
                          num_heads,  # type: int
                          in_proj_weight,  # type: Optional[Tensor]
                          in_proj_bias,  # type: Optional[Tensor]
                          bias_k,  # type: Optional[Tensor]
                          bias_v,  # type: Optional[Tensor]
                          add_zero_attn,  # type: bool
                          dropout_p,  # type: float
                          out_proj_weight,  # type: Optional[Tensor]
                          out_proj_bias,  # type: Optional[Tensor]
                          training=True,  # type: bool
                          key_padding_mask=None,  # type: Optional[Tensor]
                          need_weights=True,  # type: bool
                          attn_mask=None,  # type: Optional[Tensor]
                          use_separate_proj_weight=False,  # type: bool
                          q_proj_weight=None,  # type: Optional[Tensor]
                          k_proj_weight=None,  # type: Optional[Tensor]
                          v_proj_weight=None,  # type: Optional[Tensor]
                          static_k=None,  # type: Optional[Tensor]
                          static_v=None,  # type: Optional[Tensor]

                          key_as_static=False,  # type: bool
                          query_as_static=False,  # type: bool
                          value_as_static=False,  # type: bool
                          value_as_connected=False,  # type: bool
                          normalize_pattern=False,  # type: bool
                          p_norm_weight=None,  # type: Optional[Tensor]
                          p_norm_bias=None,  # type: Optional[Tensor]
                          head_dim=None,  # type: Optional[int]
                          pattern_dim=None,  # type: Optional[int]
                          scaling=None,  # type: Optional[Union[float, Tensor]]
                          update_steps_max=0,  # type: Optional[Union[int, Tensor]]
                          update_steps_eps=1e-4,  # type: Union[float, Tensor]
                          return_raw_associations=False,  # type: bool
                          return_projected_patterns=False  # type: bool
                          ):
    # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
    r"""
    Compute a (multi-head) Hopfield association, optionally iterating the retrieval update.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
            See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks.
        embed_dim_to_check: expected size of the last dimension of ``query``.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias (no projection if the
            weight is ``None``).
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted
            for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the projection weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.

        key_as_static: interpret specified key as being static (no projection applied).
        query_as_static: interpret specified query as being static (no projection applied).
        value_as_static: interpret specified value as being static (no projection applied).
        value_as_connected: connect value projection with key projection.
        normalize_pattern: enable normalization of patterns.
        p_norm_weight, p_norm_bias: pattern normalization weight and bias.
        head_dim: dimensionality of each head.
        pattern_dim: dimensionality of each projected value input.
        scaling: scaling of association heads, often represented as beta (one entry per head).
        update_steps_max: maximum count of association update steps (None equals to infinity).
        update_steps_eps: minimum difference threshold between two consecutive association update steps.
        return_raw_associations: return raw association (softmax) values, unmodified.
        return_projected_patterns: return pattern projection values, unmodified.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence
          length. 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence
          length, S is the source sequence length. attn_mask ensures that position i is allowed to attend the
          unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, head_dim)`, where S is the source sequence length, N is the
          batch size.
        - static_v: :math:`(N*num_heads, S, pattern_dim)`, where S is the source sequence length, N is the
          batch size.

        - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads.

        Outputs (always a 4-tuple):
        - attn_output: :math:`(L, N, E)`, where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)`, averaged over heads; ``None`` unless ``need_weights``.
        - attn_raw: :math:`(N, num_heads, L, S)`; ``None`` unless ``return_raw_associations``.
        - projected patterns: :math:`(N, num_heads, S, -1)`; ``None`` unless ``return_projected_patterns``.
    """
    # Imported locally: the deprecation warnings below need it, and the module itself never imports it.
    import warnings

    if not torch.jit.is_scripting():
        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
                    out_proj_weight, out_proj_bias)
        if any([type(t) is not torch.Tensor for t in tens_ops]) and nn.functional.has_torch_function(tens_ops):
            # Every keyword has to be passed through, otherwise the override sees a different call.
            return nn.functional.handle_torch_function(
                hopfield_core_forward, tens_ops, query, key, value,
                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
                bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
                need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v,
                key_as_static=key_as_static, query_as_static=query_as_static,
                value_as_static=value_as_static, value_as_connected=value_as_connected,
                normalize_pattern=normalize_pattern, p_norm_weight=p_norm_weight, p_norm_bias=p_norm_bias,
                head_dim=head_dim, pattern_dim=pattern_dim, scaling=scaling, update_steps_max=update_steps_max,
                update_steps_eps=update_steps_eps, return_raw_associations=return_raw_associations,
                return_projected_patterns=return_projected_patterns)
    tgt_len, bsz, embed_dim = query.shape[0], value.shape[1], query.shape[2]
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    assert (scaling is None) or (type(scaling) in (float, torch.Tensor))
    if type(scaling) == torch.Tensor:
        assert scaling.ndimension() == 1 and scaling.shape[0] == num_heads, "only one entry per head."

    # Bring the update step limit into a per-head tensor form (negative entries encode "no limit").
    assert (update_steps_max is None) or (type(update_steps_max) in (int, torch.Tensor))
    if type(update_steps_max) == torch.Tensor:
        assert update_steps_max.ndimension() == 1 and update_steps_max.shape[0] == num_heads, \
            "only one entry per head."
    elif type(update_steps_max) == int:
        update_steps_max = torch.tensor([update_steps_max] * num_heads, dtype=torch.int32, device=query.device)
    elif update_steps_max is None:
        update_steps_max = -torch.ones(size=(num_heads,), dtype=torch.int32, device=query.device)

    assert type(update_steps_eps) in (float, torch.Tensor)
    if type(update_steps_eps) == torch.Tensor:
        assert update_steps_eps.ndimension() == 1 and update_steps_eps.shape[0] == num_heads, \
            "only one entry per head."
        assert (update_steps_eps <= 0.0).sum() == 0, "only positive thresholds allowed."
        update_steps_eps = update_steps_eps.to(device=query.device)
    elif type(update_steps_eps) == float:
        assert update_steps_eps > 0, "only positive thresholds allowed."
        update_steps_eps = torch.tensor([update_steps_eps] * num_heads, dtype=query.dtype, device=query.device)

    # Adapt dimensionality of each head.
    if head_dim is None:
        head_dim = embed_dim // num_heads
        assert head_dim * num_heads == embed_dim, r'embed_dim must be divisible by num_heads.'
    hopfield_dim = num_heads * head_dim

    # Adapt dimensionality of each value projection.
    if pattern_dim is None:
        pattern_dim = head_dim
    assert (not value_as_connected) or (pattern_dim == head_dim)

    q, k, v, xi, src_len = None, None, None, None, 0
    update_step, xi_old = 0, None
    update_active_heads = torch.tensor([[[True]]] * num_heads * bsz, device=query.device)
    assert update_active_heads.any(), "at least one head needs to be active."

    ####################################################################################################################
    #                                         BEGIN HOPFIELD UPDATE ITERATION                                          #
    ####################################################################################################################

    while update_active_heads.any():

        # The query is already projected into the "Hopfield" space at "update_step" equals 0.
        # No more projection necessary if "update_step" greater than 0.
        if update_step == 0:
            if not use_separate_proj_weight:

                if torch.equal(query, key) and torch.equal(key, value) and not (
                        key_as_static or query_as_static or value_as_static):
                    # self-attention
                    q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

                elif torch.equal(key, value) and not (key_as_static or value_as_static):
                    # encoder-decoder attention
                    _start, _end = 0, hopfield_dim
                    if query_as_static:
                        q = query.repeat(1, num_heads, 1)
                    else:
                        # This is inline in_proj function with in_proj_weight and in_proj_bias
                        _b = in_proj_bias
                        _w = in_proj_weight[_start:_end, :]
                        if _b is not None:
                            _b = _b[_start:_end]
                        q = nn.functional.linear(query, _w, _b)
                        _start = hopfield_dim
                    _end = None

                    if key is None:
                        assert value is None
                        k = None
                        v = None
                    else:

                        # This is inline in_proj function with in_proj_weight and in_proj_bias
                        _b = in_proj_bias
                        _w = in_proj_weight[_start:_end, :]
                        if _b is not None:
                            _b = _b[_start:_end]
                        k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)

                else:
                    _start, _end = 0, hopfield_dim
                    if query_as_static:
                        q = query.repeat(1, num_heads, 1)
                    else:
                        # This is inline in_proj function with in_proj_weight and in_proj_bias
                        _b = in_proj_bias
                        _w = in_proj_weight[_start:_end, :]
                        if _b is not None:
                            _b = _b[_start:_end]
                        q = nn.functional.linear(query, _w, _b)
                        _start += hopfield_dim
                        _end += hopfield_dim

                    if key_as_static:
                        k = key.repeat(1, num_heads, 1)
                    else:
                        # This is inline in_proj function with in_proj_weight and in_proj_bias
                        _b = in_proj_bias
                        _w = in_proj_weight[_start:_end, :]
                        if _b is not None:
                            _b = _b[_start:_end]
                        k = nn.functional.linear(key, _w, _b)
                        _start += hopfield_dim
                        _end += hopfield_dim

                    if value_as_static:
                        v = value.repeat(1, num_heads, 1)
                    else:
                        # This is inline in_proj function with in_proj_weight and in_proj_bias.
                        # The value rows span "num_heads * pattern_dim" entries, which only coincides
                        # with "hopfield_dim" if pattern_dim equals head_dim.
                        _end = _start + num_heads * pattern_dim
                        _b = in_proj_bias
                        _w = in_proj_weight[_start:_end, :]
                        if _b is not None:
                            _b = _b[_start:_end]
                        v = nn.functional.linear(value, _w, _b)
            else:
                _start, _end = 0, hopfield_dim
                if query_as_static:
                    q = query.repeat(1, num_heads, 1)
                else:
                    q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
                    len1, len2 = q_proj_weight_non_opt.size()
                    assert len1 == hopfield_dim and len2 == query.size(-1)
                    if in_proj_bias is not None:
                        q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias[_start:_end])
                        _start += hopfield_dim
                        _end += hopfield_dim
                    else:
                        q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias)

                v = value
                if key_as_static:
                    k = key.repeat(1, num_heads, 1)
                else:
                    k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
                    len1, len2 = k_proj_weight_non_opt.size()
                    assert len1 == hopfield_dim and len2 == key.size(-1)

                    _bias = None if in_proj_bias is None else in_proj_bias[_start:_end]
                    k = nn.functional.linear(key, k_proj_weight_non_opt, _bias)
                    if value_as_connected:
                        # A connected value is based on the projected key instead of the raw value.
                        v = nn.functional.linear(v, k_proj_weight_non_opt, _bias)
                    _start += hopfield_dim
                    _end += num_heads * pattern_dim

                if value_as_static:
                    if not (value_as_connected or key_as_static):
                        v = v.repeat(1, num_heads, 1)
                else:
                    v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
                    len1, len2 = v_proj_weight_non_opt.size()
                    assert len1 == (num_heads * pattern_dim) and len2 == v.size(-1)
                    if in_proj_bias is not None:
                        v = nn.functional.linear(v, v_proj_weight_non_opt, in_proj_bias[_start:])
                    else:
                        v = nn.functional.linear(v, v_proj_weight_non_opt, in_proj_bias)

            if attn_mask is not None:
                assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
                       attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or \
                       attn_mask.dtype == torch.bool, \
                       'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
                if attn_mask.dtype == torch.uint8:
                    warnings.warn(
                        "Byte tensor for attn_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.")
                    attn_mask = attn_mask.to(torch.bool)

                if attn_mask.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
                    if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                        raise RuntimeError('The size of the 2D attn_mask is not correct.')
                elif attn_mask.dim() == 3:
                    if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                        raise RuntimeError('The size of the 3D attn_mask is not correct.')
                else:
                    raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
                # attn_mask's dim is 3 now.

            # Optionally normalize patterns.
            if normalize_pattern:
                q = torch.nn.functional.layer_norm(
                    input=q.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,),
                    weight=p_norm_weight, bias=p_norm_bias).reshape(shape=q.shape)
                k = torch.nn.functional.layer_norm(
                    input=k.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,),
                    weight=p_norm_weight, bias=p_norm_bias).reshape(shape=k.shape)

        else:
            # Only heads that are still active get a new state pattern (retrieved from the stored patterns).
            active_xi = xi.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:]))
            active_k = k.masked_select(mask=update_active_heads).view(size=(-1, *k.shape[1:]))
            q = torch.masked_scatter(input=q, mask=update_active_heads, source=torch.bmm(active_xi, active_k))

        # Optionally scale association heads (each head separately).
        # NOTE(review): "repeat" tiles the per-head factors along the last axis -- confirm that this
        # matches the head layout assumed by the reshaping below if more than one head is used.
        if type(scaling) == float:
            q = q * scaling
        elif type(scaling) == torch.Tensor:
            q = q * scaling.view(1, 1, -1).repeat(repeats=(1, 1, q.shape[2] // scaling.shape[0]))

        if update_step == 0:
            # convert ByteTensor key_padding_mask to bool
            if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for key_padding_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.")
                key_padding_mask = key_padding_mask.to(torch.bool)

            if bias_k is not None and bias_v is not None:
                # The static flags are booleans: they have to be tested for truthiness (a comparison
                # with None is never true and would make the bias concatenation unreachable).
                if static_k is None and static_v is None and not key_as_static and not value_as_static:
                    k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                    v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                    if attn_mask is not None:
                        attn_mask = nn.functional.pad(attn_mask, [0, 1])
                    if key_padding_mask is not None:
                        key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1])
                else:
                    assert static_k is None, "bias cannot be added to static key."
                    assert static_v is None, "bias cannot be added to static value."
                    assert not key_as_static, "bias cannot be added to static key."
                    assert not value_as_static, "bias cannot be added to static value."
            else:
                assert bias_k is None
                assert bias_v is None

            # Split the heads: (length, batch, heads * dim) -> (batch * heads, length, dim).
            q = q.contiguous().view(tgt_len, -1, head_dim).transpose(0, 1)
            if k is not None:
                k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
            if v is not None:
                v = v.contiguous().view(v.shape[0], bsz * num_heads, -1).transpose(0, 1)

            if static_k is not None:
                assert static_k.size(0) == bsz * num_heads
                assert static_k.size(2) == head_dim
                k = static_k

            if static_v is not None:
                assert static_v.size(0) == bsz * num_heads
                assert static_v.size(2) == pattern_dim
                v = static_v

            src_len = k.size(1)

            if key_padding_mask is not None:
                assert key_padding_mask.size(0) == bsz
                assert key_padding_mask.size(1) == src_len

            if add_zero_attn:
                src_len += 1
                k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
                v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
                if attn_mask is not None:
                    attn_mask = nn.functional.pad(attn_mask, [0, 1])
                if key_padding_mask is not None:
                    key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1])

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

        # Compute new xi for Hopfield retrieve iterations.
        if xi is None:
            xi = nn.functional.softmax(attn_output_weights, dim=-1)
        else:
            xi = torch.masked_scatter(input=xi, mask=update_active_heads, source=nn.functional.softmax(
                attn_output_weights.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:])), dim=-1))

        # Compute threshold-based stopping criterion for Hopfield retrieve iterations.
        with torch.no_grad():
            xi_active = xi.view(size=(bsz, num_heads, tgt_len, src_len))
            update_active_heads = (update_step < update_steps_max) | (update_steps_max < 0)
            if xi_old is not None:
                update_active_heads &= ((xi_old - xi_active).norm(p=2, dim=(2, 3)).max(axis=0)[0]) > update_steps_eps
            update_active_heads = update_active_heads.unsqueeze(dim=1).unsqueeze(dim=2).repeat(repeats=(bsz, 1, 1))
            xi_old = xi_active
        update_step += 1

    ####################################################################################################################
    #                                          END HOPFIELD UPDATE ITERATION                                           #
    ####################################################################################################################

    attn_output_weights = nn.functional.dropout(xi, p=dropout_p, training=training)
    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.shape[:2]) == [bsz * num_heads, tgt_len]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1)
    if out_proj_weight is not None:
        assert attn_output.shape[2] == num_heads * pattern_dim
        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)

    xi = xi.view(bsz, num_heads, tgt_len, src_len) if return_raw_associations else None
    v = v.view(bsz, num_heads, src_len, -1) if return_projected_patterns else None
    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads, xi, v
    else:
        return attn_output, None, xi, v
# Patch metadata, kept verbatim from the flattened diff (header of the transformer.py hunk):
# a/src/mhnfs/hopfield/examples/latch_sequence/modules/transformer.py b/src/mhnfs/hopfield/examples/latch_sequence/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..884e0cb4b57610cf1daf8147f2c3d59f17824750 --- /dev/null +++ b/src/mhnfs/hopfield/examples/latch_sequence/modules/transformer.py @@ -0,0 +1,209 @@
import torch
import torch.nn as nn

from copy import deepcopy
from torch import Tensor
from torch.nn.modules import Module
from typing import Optional, Tuple, Union

from . import Hopfield


class HopfieldEncoderLayer(Module):
    """
    Module with underlying Hopfield association to be used as an encoder in transformer-like architectures.

    The forward pass applies the Hopfield association and then a two-layer feed-forward block. The result of
    each of the two stages is passed through dropout, added to its input and layer-normalised.
    """

    def __init__(self,
                 hopfield_association: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Initialise a new instance of a Hopfield association-based encoder module.

        :param hopfield_association: instance of Hopfield association module (a deep copy is stored)
        :param dim_feedforward: depth of the linear projections applied internally
        :param dropout: dropout probability to be applied internally
        :param activation: name of an attribute of the top-level ``torch`` namespace (e.g. ``relu``) to be
            applied on the result of the first internal linear projection
        """
        super(HopfieldEncoderLayer, self).__init__()
        # A deep copy is stored, so the instance handed in by the caller is not shared with this layer.
        self.hopfield_association = deepcopy(hopfield_association)

        # Feed-forward block: state_pattern_dim -> dim_feedforward -> state_pattern_dim.
        self.linear_residual = nn.Linear(self.hopfield_association.state_pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, self.hopfield_association.state_pattern_dim)

        self.norm_residual = nn.LayerNorm(self.hopfield_association.state_pattern_dim)
        self.norm_output = nn.LayerNorm(self.hopfield_association.state_pattern_dim)
        self.dropout_hopfield_association = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is resolved by name on the torch module; an unknown name yields None and is rejected.
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset parameters, including Hopfield association.

        Only sub-modules that expose a ``reset_parameters`` method are reset.

        :return: None
        """
        for module in (self.hopfield_association, self.linear_residual,
                       self.linear_output, self.norm_residual, self.norm_output):
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield encoding on specified data.

        :param src: data to be processed by Hopfield encoder module
        :param src_mask: mask to be applied on association matrix
        :param src_key_padding_mask: mask to be applied on stored patterns
        :return: Hopfield-encoded input data
        """
        # Stage 1: Hopfield association, dropout, residual connection, layer normalisation.
        data_associated = self.hopfield_association(
            input=src, stored_pattern_padding_mask=src_key_padding_mask, association_mask=src_mask)
        src = src + self.dropout_hopfield_association(input=data_associated)
        src = self.norm_residual(input=src)

        # Stage 2: feed-forward block, dropout, residual connection, layer normalisation.
        result_residual_inner = self.activation_residual(input=self.linear_residual(input=src))
        data_associated = self.linear_output(input=self.dropout_residual(input=result_residual_inner))
        src = src + self.dropout_output(input=data_associated)

        return self.norm_output(input=src)

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association.get_association_matrix(input=input)

    @property
    def batch_first(self) -> bool:
        """Flag ``batch_first`` as reported by the wrapped Hopfield association."""
        return self.hopfield_association.batch_first

    @property
    def input_size(self) -> int:
        """Input size as reported by the wrapped Hopfield association."""
        return self.hopfield_association.input_size

    @property
    def output_size(self) -> int:
        """Number of output features of the final linear projection."""
        return self.linear_output.out_features


class HopfieldDecoderLayer(Module):
    """
    Module with underlying Hopfield associations to be used as a decoder in transformer-like architectures.

    The forward pass applies a Hopfield self-association, a Hopfield cross-association and a two-layer
    feed-forward block. The result of each stage is passed through dropout, added to its input and
    layer-normalised.
    """

    def __init__(self,
                 hopfield_association_self: Hopfield,
                 hopfield_association_cross: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Initialise a new instance of a Hopfield association-based decoder module.

        :param hopfield_association_self: instance of Hopfield self-association module (a deep copy is stored)
        :param hopfield_association_cross: instance of Hopfield cross-association module (a deep copy is stored)
        :param dim_feedforward: depth of the linear projections applied internally
        :param dropout: dropout probability to be applied internally
        :param activation: name of an attribute of the top-level ``torch`` namespace (e.g. ``relu``) to be
            applied on the result of the first internal linear projection
        """
        super(HopfieldDecoderLayer, self).__init__()
        # Deep copies are stored, so the instances handed in by the caller are not shared with this layer.
        self.hopfield_association_self = deepcopy(hopfield_association_self)
        self.hopfield_association_cross = deepcopy(hopfield_association_cross)

        # All internal widths are derived from the state pattern dimension of the self-association.
        self.linear_residual = nn.Linear(self.hopfield_association_self.state_pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, self.hopfield_association_self.state_pattern_dim)

        self.norm_residual_self = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.norm_residual_cross = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.norm_output = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.dropout_hopfield_association_self = nn.Dropout(dropout)
        self.dropout_hopfield_association_cross = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is resolved by name on the torch module; an unknown name yields None and is rejected.
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset parameters, including both Hopfield associations.

        Only sub-modules that expose a ``reset_parameters`` method are reset.

        :return: None
        """
        for module in (self.hopfield_association_self, self.hopfield_association_cross,
                       self.linear_residual, self.linear_output, self.norm_residual_self,
                       self.norm_residual_cross, self.norm_output):
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield decoding on specified data.

        :param tgt: data to be processed by Hopfield decoder module (self-association)
        :param memory: data to be processed by Hopfield decoder module (cross-association)
        :param tgt_mask: mask to be applied on self-association matrix
        :param memory_mask: mask to be applied on cross-association matrix
        :param tgt_key_padding_mask: padding mask handed to the self-association as stored pattern padding mask
        :param memory_key_padding_mask: padding mask handed to the cross-association as stored pattern
            padding mask
        :return: Hopfield-decoded input
        """
        # Stage 1: self-association on the target, dropout, residual connection, layer normalisation.
        data_associated = self.hopfield_association_self(
            input=tgt, stored_pattern_padding_mask=tgt_key_padding_mask,
            association_mask=tgt_mask)
        tgt = tgt + self.dropout_hopfield_association_self(input=data_associated)
        tgt = self.norm_residual_self(input=tgt)

        # Stage 2: cross-association between memory and target.
        # NOTE(review): the tuple is presumably (stored pattern, state pattern, pattern projection) --
        # confirm against Hopfield.forward.
        data_associated = self.hopfield_association_cross(
            input=(memory, tgt, memory), stored_pattern_padding_mask=memory_key_padding_mask,
            association_mask=memory_mask)
        tgt = tgt + self.dropout_hopfield_association_cross(input=data_associated)
        tgt = self.norm_residual_cross(input=tgt)

        # Stage 3: feed-forward block, dropout, residual connection, layer normalisation.
        result_residual_inner = self.activation_residual(input=self.linear_residual(input=tgt))
        data_associated = self.linear_output(input=self.dropout_residual(input=result_residual_inner))
        tgt = tgt + self.dropout_output(input=data_associated)
        return self.norm_output(input=tgt)

    def get_association_matrix_self(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield self-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_self.get_association_matrix(input=input)

    def get_association_matrix_cross(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield cross-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_cross.get_association_matrix(input=input)

    @property
    def batch_first(self) -> bool:
        """Flag ``batch_first`` as reported by the wrapped Hopfield self-association."""
        return self.hopfield_association_self.batch_first

    @property
    def input_size(self) -> int:
        """Input size as reported by the wrapped Hopfield self-association."""
        return self.hopfield_association_self.input_size

    @property
    def output_size(self) -> int:
        """Number of output features of the final linear projection."""
        # The layer is registered as ``linear_output``; the former ``linear_output_self`` never existed
        # and made this property raise an AttributeError.
        return self.linear_output.out_features

# Patch metadata, kept verbatim from the flattened diff (binary resources following transformer.py):
# diff --git a/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_adapted.pdf b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_adapted.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a372354c143235492963af3c030fe9e57ef52384 Binary files /dev/null and b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_adapted.pdf differ diff --git a/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_base.pdf b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_base.pdf new
file mode 100644 index 0000000000000000000000000000000000000000..d79776dad1aaa8521d5e262da58b10a72729f9ec Binary files /dev/null and b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_base.pdf differ diff --git a/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_lookup.pdf b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_lookup.pdf new file mode 100644 index 0000000000000000000000000000000000000000..7d77e71d3d435bb1c7f2cba80325f3568351c5ca Binary files /dev/null and b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_lookup.pdf differ diff --git a/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_lookup_adapted.pdf b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_lookup_adapted.pdf new file mode 100644 index 0000000000000000000000000000000000000000..4311173d107115696730da2a301f5214834bf6c0 Binary files /dev/null and b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_lookup_adapted.pdf differ diff --git a/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_pooling.pdf b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_pooling.pdf new file mode 100644 index 0000000000000000000000000000000000000000..0b6a6629bd3338c42b73d976bfcb60464ffb6387 Binary files /dev/null and b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_pooling.pdf differ diff --git a/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_pooling_adapted.pdf b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_pooling_adapted.pdf new file mode 100644 index 0000000000000000000000000000000000000000..42603ab847c67dd4ab644a2b39a593f034f3a539 Binary files /dev/null and b/src/mhnfs/hopfield/examples/latch_sequence/resources/hopfield_pooling_adapted.pdf differ diff --git a/src/mhnfs/hopfield/examples/latch_sequence/resources/lstm_base.pdf b/src/mhnfs/hopfield/examples/latch_sequence/resources/lstm_base.pdf new file mode 100644 index 
0000000000000000000000000000000000000000..b04c2d77f5af97d116474bc74fb32d37a8766f78 Binary files /dev/null and b/src/mhnfs/hopfield/examples/latch_sequence/resources/lstm_base.pdf differ diff --git a/src/mhnfs/hopfield/examples/mnist_bags/LICENSE b/src/mhnfs/hopfield/examples/mnist_bags/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..236fb4bb131d2ac5692a0f8bd71c54d2b8c3f002 --- /dev/null +++ b/src/mhnfs/hopfield/examples/mnist_bags/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Maximilian Ilse and Jakub Tomczak + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/src/mhnfs/hopfield/examples/mnist_bags/README.md b/src/mhnfs/hopfield/examples/mnist_bags/README.md new file mode 100644 index 0000000000000000000000000000000000000000..27f0e35170dd51bd7f6ebf6ee18cb5e396bfe776 --- /dev/null +++ b/src/mhnfs/hopfield/examples/mnist_bags/README.md @@ -0,0 +1,28 @@ +# Application of Hopfield-based pooling on Attention-based Deep Multiple Instance Learning + +This notebook demonstrates how to apply the Hopfield pooling layer. +It is based on the PyTorch implementation of the paper [Attention-based Deep Multiple Instance Learning](https://github.com/AMLab-Amsterdam/AttentionDeepMIL) by +* Ilse, M., Tomczak, J. M., & Welling, M. (2018). Attention-based Deep Multiple Instance Learning. [arXiv preprint arXiv:1802.04712](https://arxiv.org/pdf/1802.04712.pdf). + + +## Installation +Download the PyTorch code of Attention-based Deep Multiple Instance Learning (ADMIL) from the accompanying [repository](https://github.com/AMLab-Amsterdam/AttentionDeepMIL) into a directory AttentionDeepMIL on the current directory level ([examples/mnist_bags](.)). Afterwards, line 60 of the file [model.py](https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/model.py#L60) of this repository needs to be modified. Change +```python +error = 1. - Y_hat.eq(Y).cpu().float().mean().data[0] +``` +to +```python +error = 1. - Y_hat.eq(Y).cpu().float().mean().item() +``` + +Cell 5 specifies the parameters defining the data set properties, whereas cell 15 defines the Hopfield-based pooling network. As a last step, run the notebook. + + +## Note +* The neural network with Hopfield-based pooling, implemented in cell 15 of the [mnist_bags_demo.ipynb](mnist_bags_demo.ipynb) notebook is based on the models proposed in [ADMIL](https://github.com/AMLab-Amsterdam/AttentionDeepMIL). 
+ +* The code in the [mnist_bags_demo.ipynb](mnist_bags_demo.ipynb) notebook is based on the [main.py](https://github.com/AMLab-Amsterdam/AttentionDeepMIL/blob/master/main.py) file from ADMIL. + + +## Disclaimer +The purpose of this notebook is merely to demonstrate how to use the HopfieldPooling layer. It is in no way intended as a comparison of the methods. diff --git a/src/mhnfs/hopfield/examples/mnist_bags/mnist_bags_demo.ipynb b/src/mhnfs/hopfield/examples/mnist_bags/mnist_bags_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..eb7996e2b392511f8501d753d3d1e2a770adb0b5 --- /dev/null +++ b/src/mhnfs/hopfield/examples/mnist_bags/mnist_bags_demo.ipynb @@ -0,0 +1,736 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Example: Attention-based Deep Multiple Instance Learning

" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Import general modules used for e.g. plotting.\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import sys\n", + "import torch\n", + "\n", + "# Import Hopfield-specific modules.\n", + "from modules import HopfieldPooling\n", + "\n", + "# Import auxiliary modules.\n", + "from distutils.version import LooseVersion\n", + "from typing import Optional, Tuple\n", + "\n", + "# Importing PyTorch specific modules.\n", + "from torch import Tensor\n", + "from torch.autograd import Variable\n", + "from torch.nn import Conv2d, Dropout, Linear, MaxPool2d, Module, ReLU, Sequential, Sigmoid\n", + "from torch.nn.utils import clip_grad_norm_\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader\n", + "\n", + "# Set plotting style.\n", + "sns.set()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Append the path of the Attention-based Deep Multiple Instance Learning (ADMIL) repository to the system path in order for Python to find the corresponding modules to import." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "sys.path.append(r'./AttentionDeepMIL')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Afterwards, the corresponding modules\n", + "- MnistBags\n", + "- Attention\n", + "- GatedAttention\n", + "\n", + "are imported to the global namespace." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from dataloader import MnistBags\n", + "from model import Attention, GatedAttention" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specific minimum versions of Python itself as well as of some used modules is required." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Installed Python version: 3.8.8 (βœ“)\n", + "Installed PyTorch version: 1.7.0 (βœ“)\n" + ] + } + ], + "source": [ + "python_check = '(\\u2713)' if sys.version_info >= (3, 8) else '(\\u2717)'\n", + "pytorch_check = '(\\u2713)' if torch.__version__ >= LooseVersion(r'1.5') else '(\\u2717)'\n", + "\n", + "print(f'Installed Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} {python_check}')\n", + "print(f'Installed PyTorch version: {torch.__version__} {pytorch_check}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "

Create Dataset

\n", + "\n", + "The dataset itself falls into the category of binary classification tasks in the domain of Multiple Instance Learning (MIL) problems.\n", + "The MNIST-bags task was introduced by Ilse and Tomczak:
\n", + "Ilse, M., Tomczak, J.M. and Welling, M., 2018. Attention-based deep multiple instance learning. arXiv preprint arXiv:1802.04712.

\n", + "Each bag comprises a collection of $28\\times{}28$ grayscale images/instances, whereas each instance is a sequence of pixel values in the range of $[0; 255]$. The amount of instances per bag is drawn from a Gaussian with specified mean and variance. The positive class is defined by the presence of the target number/digit, whereas the negative one by its absence. This demonstration shows that HopfieldPooling is capable of learning and filtering each bag with respect to the class-defining target number/digit. Defining arguments are:\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ArgumentValue (used in this demo)Description
target_number9Target number/digit defining class affiliation.
mean_bag_length10Mean amount of instances per bag.
var_bag_length2Variance of amount of instances per bag.
num_bag{200; 50}Amount of samples of the training as well as validation set.
\n", + "\n", + "Let's define the dataset using previously mentioned properties as well as a logging directory for storing all auxiliary outputs like performance plots." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(r'cuda:0' if torch.cuda.is_available() else r'cpu')\n", + "\n", + "# Create data loader of training set.\n", + "data_loader_train = DataLoader(MnistBags(\n", + " target_number=9,\n", + " mean_bag_length=10,\n", + " var_bag_length=2,\n", + " num_bag=200,\n", + " train=True\n", + "), batch_size=1, shuffle=True)\n", + "\n", + "# Create data loader of validation set.\n", + "data_loader_eval = DataLoader(MnistBags(\n", + " target_number=9,\n", + " mean_bag_length=10,\n", + " var_bag_length=2,\n", + " num_bag=50,\n", + " train=False\n", + "), batch_size=1, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "log_dir = f'resources/'\n", + "os.makedirs(log_dir, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "

Create Auxiliaries

\n", + "\n", + "Before digging into Hopfield-based networks, a few auxiliary variables and functions need to be defined. This is nothing special with respect to Hopfield-based networks, but rather common preparation work of (almost) every machine learning setting (e.g. definition of a data loader as well as a training loop). We will see, that this comprises the most work of this whole demo." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def train_epoch(network: Module,\n", + " optimiser: AdamW,\n", + " data_loader: DataLoader\n", + " ) -> Tuple[float, float, float]:\n", + " \"\"\"\n", + " Execute one training epoch.\n", + " \n", + " :param network: network instance to train\n", + " :param optimiser: optimiser instance responsible for updating network parameters\n", + " :param data_loader: data loader instance providing training data\n", + " :return: tuple comprising training loss, training error as well as accuracy\n", + " \"\"\"\n", + " network.train()\n", + " losses, errors, accuracies = [], [], []\n", + " for data, target in data_loader:\n", + " data, target = data.to(device=device), target[0].to(device=device)\n", + "\n", + " # Process data by Hopfield-based network.\n", + " loss = network.calculate_objective(data, target)[0]\n", + "\n", + " # Update network parameters.\n", + " optimiser.zero_grad()\n", + " loss.backward()\n", + " clip_grad_norm_(parameters=network.parameters(), max_norm=1.0, norm_type=2)\n", + " optimiser.step()\n", + "\n", + " # Compute performance measures of current model.\n", + " error, prediction = network.calculate_classification_error(data, target)\n", + " accuracy = (prediction == target).to(dtype=torch.float32).mean()\n", + " accuracies.append(accuracy.detach().item())\n", + " errors.append(error)\n", + " losses.append(loss.detach().item())\n", + " \n", + " # Report progress of training procedure.\n", + " return sum(losses) / len(losses), sum(errors) / len(errors), 
sum(accuracies) / len(accuracies)\n", + "\n", + "\n", + "def eval_iter(network: Module,\n", + " data_loader: DataLoader\n", + " ) -> Tuple[float, float, float]:\n", + " \"\"\"\n", + " Evaluate the current model.\n", + " \n", + " :param network: network instance to evaluate\n", + " :param data_loader: data loader instance providing validation data\n", + " :return: tuple comprising validation loss, validation error as well as accuracy\n", + " \"\"\"\n", + " network.eval()\n", + " with torch.no_grad():\n", + " losses, errors, accuracies = [], [], []\n", + " for data, target in data_loader:\n", + " data, target = data.to(device=device), target[0].to(device=device)\n", + "\n", + " # Process data by Hopfield-based network.\n", + " loss = network.calculate_objective(data, target)[0]\n", + "\n", + " # Compute performance measures of current model.\n", + " error, prediction = network.calculate_classification_error(data, target)\n", + " accuracy = (prediction == target).to(dtype=torch.float32).mean()\n", + " accuracies.append(accuracy.detach().item())\n", + " errors.append(error)\n", + " losses.append(loss.detach().item())\n", + "\n", + " # Report progress of validation procedure.\n", + " return sum(losses) / len(losses), sum(errors) / len(errors), sum(accuracies) / len(accuracies)\n", + "\n", + " \n", + "def operate(network: Module,\n", + " optimiser: AdamW,\n", + " data_loader_train: DataLoader,\n", + " data_loader_eval: DataLoader,\n", + " num_epochs: int = 1\n", + " ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:\n", + " \"\"\"\n", + " Train the specified network by gradient descent using backpropagation.\n", + " \n", + " :param network: network instance to train\n", + " :param optimiser: optimiser instance responsible for updating network parameters\n", + " :param data_loader_train: data loader instance providing training data\n", + " :param data_loader_eval: data loader instance providing validation data\n", + " :param num_epochs: amount of epochs to train\n", + 
" :return: data frame comprising training as well as evaluation performance\n", + " \"\"\"\n", + " losses, errors, accuracies = {r'train': [], r'eval': []}, {r'train': [], r'eval': []}, {r'train': [], r'eval': []}\n", + " for epoch in range(num_epochs):\n", + " \n", + " # Train network.\n", + " performance = train_epoch(network, optimiser, data_loader_train)\n", + " losses[r'train'].append(performance[0])\n", + " errors[r'train'].append(performance[1])\n", + " accuracies[r'train'].append(performance[2])\n", + " \n", + " # Evaluate current model.\n", + " performance = eval_iter(network, data_loader_eval)\n", + " losses[r'eval'].append(performance[0])\n", + " errors[r'eval'].append(performance[1])\n", + " accuracies[r'eval'].append(performance[2])\n", + " \n", + " # Report progress of training and validation procedures.\n", + " return pd.DataFrame(losses), pd.DataFrame(errors), pd.DataFrame(accuracies)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def set_seed(seed: int = 42) -> None:\n", + " \"\"\"\n", + " Set seed for all underlying (pseudo) random number sources.\n", + " \n", + " :param seed: seed to be used\n", + " :return: None\n", + " \"\"\"\n", + " torch.manual_seed(42)\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + "\n", + "def plot_performance(loss: pd.DataFrame,\n", + " error: pd.DataFrame,\n", + " accuracy: pd.DataFrame,\n", + " log_file: str\n", + " ) -> None:\n", + " \"\"\"\n", + " Plot and save loss and accuracy.\n", + " \n", + " :param loss: loss to be plotted\n", + " :param error: error to be plotted\n", + " :param accuracy: accuracy to be plotted\n", + " :param log_file: target file for storing the resulting plot\n", + " :return: None\n", + " \"\"\"\n", + " fig, ax = plt.subplots(1, 3, figsize=(20, 7))\n", + " \n", + " loss_plot = sns.lineplot(data=loss, ax=ax[0])\n", + " loss_plot.set(xlabel=r'Epoch', ylabel=r'Loss')\n", 
+ " \n", + " error_plot = sns.lineplot(data=error, ax=ax[1])\n", + " error_plot.set(xlabel=r'Epoch', ylabel=r'Error')\n", + " \n", + " accuracy_plot = sns.lineplot(data=accuracy, ax=ax[2])\n", + " accuracy_plot.set(xlabel=r'Epoch', ylabel=r'Accuracy')\n", + " \n", + " fig.tight_layout()\n", + " fig.savefig(log_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Attention-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "set_seed()\n", + "network = Attention().to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=5e-4, weight_decay=1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate Attention-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "losses, errors, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, error=errors, accuracy=accuracies, log_file=f'{log_dir}/attention_base.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

GatedAttention-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "set_seed()\n", + "network = GatedAttention().to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=5e-4, weight_decay=1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate GatedAttention-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "losses, errors, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, error=errors, accuracy=accuracies, log_file=f'{log_dir}/gated_attention_base.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Hopfield-based Pooling

" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "class HfPooling(Module): \n", + " def __init__(self):\n", + " \"\"\"\n", + " Initialize a new instance of a Hopfield-based pooling network.\n", + " \n", + " Note: all hyperparameters of the network are fixed for demonstration purposes.\n", + " Morevover, most of the notation of the original implementation is kept in order\n", + " to be easier comparable (partially ignoring PEP8).\n", + " \"\"\"\n", + " super(HfPooling, self).__init__()\n", + " self.L = 500\n", + " self.D = 128\n", + " self.K = 1\n", + "\n", + " self.feature_extractor_part1 = Sequential(\n", + " Conv2d(1, 20, kernel_size=5),\n", + " ReLU(),\n", + " MaxPool2d(2, stride=2),\n", + " Conv2d(20, 50, kernel_size=5),\n", + " ReLU(),\n", + " MaxPool2d(2, stride=2)\n", + " )\n", + " self.feature_extractor_part2 = Sequential(\n", + " Linear(50 * 4 * 4, self.L),\n", + " ReLU(),\n", + " )\n", + " self.hopfield_pooling = HopfieldPooling(\n", + " input_size=self.L, hidden_size=32, output_size=self.L, num_heads=1\n", + " )\n", + " self.dp = Dropout(\n", + " p=0.1\n", + " )\n", + " self.classifier = Sequential(\n", + " Linear(self.L * self.K, 1),\n", + " Sigmoid()\n", + " )\n", + " \n", + " def forward(self, input: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:\n", + " \"\"\"\n", + " Compute result of Hopfield-based pooling network on specified data.\n", + " \n", + " :param input: data to be processed by the Hopfield-based pooling network\n", + " :return: result as computed by the Hopfield-based pooling network\n", + " \"\"\"\n", + " x = input.squeeze(0)\n", + " H = self.feature_extractor_part1(x)\n", + " H = H.view(-1, 50 * 4 * 4)\n", + " H = self.feature_extractor_part2(H)\n", + " \n", + " H = H.unsqueeze(0)\n", + " H = self.hopfield_pooling(H)\n", + " H = H.squeeze(0)\n", + " H = 
self.dp(H)\n", + "\n", + " Y_prob = self.classifier(H)\n", + " Y_hat = torch.ge(Y_prob, 0.5).float()\n", + "\n", + " return Y_prob, Y_hat, None\n", + "\n", + " def calculate_classification_error(self, input: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:\n", + " \"\"\"\n", + " Compute classification error of current model.\n", + " \n", + " :param input: data to be processed by the Hopfield-based pooling network\n", + " :param target: target to be used to compute the classification error of the current model\n", + " :return: classification error as well as predicted class\n", + " \"\"\"\n", + " Y = target.float()\n", + " _, Y_hat, _ = self.forward(input)\n", + " error = 1.0 - Y_hat.eq(Y).cpu().float().mean().item()\n", + "\n", + " return error, Y_hat\n", + "\n", + " def calculate_objective(self, input: Tensor, target: Tensor) -> Tuple[Tensor, Optional[Tensor]]:\n", + " \"\"\"\n", + " Compute objective of the current model.\n", + " \n", + " :param input: data to be processed by the Hopfield-based pooling network\n", + " :param target: target to be used to compute the objective of the current model\n", + " :return: objective as well as dummy A (see accompanying paper for more information)\n", + " \"\"\"\n", + " Y = target.float()\n", + " Y_prob, _, A = self.forward(input)\n", + " Y_prob = torch.clamp(Y_prob, min=1e-5, max=(1.0 - 1e-5))\n", + " neg_log_likelihood = -1.0 * (Y * torch.log(Y_prob) + (1.0 - Y) * torch.log(1.0 - Y_prob))\n", + "\n", + " return neg_log_likelihood, A" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "set_seed()\n", + "network = HfPooling().to(device=device)\n", + "optimiser = AdamW(params=network.parameters(), lr=5e-4, weight_decay=1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Operate HopfieldPooling-based Network

" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "losses, errors, accuracies = operate(\n", + " network=network,\n", + " optimiser=optimiser,\n", + " data_loader_train=data_loader_train,\n", + " data_loader_eval=data_loader_eval,\n", + " num_epochs=20)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(loss=losses, error=errors, accuracy=accuracies, log_file=f'{log_dir}/hopfield_pooling.pdf')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/mhnfs/hopfield/examples/mnist_bags/modules/__init__.py b/src/mhnfs/hopfield/examples/mnist_bags/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c11713739acb1aca3b32cca53acf0c3faf1fa9f6 --- /dev/null +++ b/src/mhnfs/hopfield/examples/mnist_bags/modules/__init__.py @@ -0,0 +1,898 @@ +import torch +import torch.nn as nn + +from math import sqrt +from torch import Tensor +from torch.nn import Module, Parameter +from typing import Optional, Tuple, Union + +from .activation import HopfieldCore + + +class Hopfield(Module): + """ + Module with underlying Hopfield association. 
+ """ + + def __init__(self, + input_size: Optional[int] = None, + hidden_size: Optional[int] = None, + output_size: Optional[int] = None, + pattern_size: Optional[int] = None, + num_heads: int = 1, + scaling: Optional[Union[float, Tensor]] = None, + update_steps_max: Optional[Union[int, Tensor]] = 0, + update_steps_eps: Union[float, Tensor] = 1e-4, + + normalize_stored_pattern: bool = True, + normalize_stored_pattern_affine: bool = True, + normalize_state_pattern: bool = True, + normalize_state_pattern_affine: bool = True, + normalize_pattern_projection: bool = True, + normalize_pattern_projection_affine: bool = True, + normalize_hopfield_space: bool = False, + normalize_hopfield_space_affine: bool = False, + stored_pattern_as_static: bool = False, + state_pattern_as_static: bool = False, + pattern_projection_as_static: bool = False, + pattern_projection_as_connected: bool = False, + stored_pattern_size: Optional[int] = None, + pattern_projection_size: Optional[int] = None, + + batch_first: bool = True, + association_activation: Optional[str] = None, + dropout: float = 0.0, + input_bias: bool = True, + concat_bias_pattern: bool = False, + add_zero_association: bool = False, + disable_out_projection: bool = False + ): + """ + Initialise new instance of a Hopfield module. 
+ + :param input_size: depth of the input (state pattern) + :param hidden_size: depth of the association space + :param output_size: depth of the output projection + :param pattern_size: depth of patterns to be selected + :param num_heads: amount of parallel association heads + :param scaling: scaling of association heads, often represented as beta (one entry per head) + :param update_steps_max: maximum count of association update steps (None equals to infinity) + :param update_steps_eps: minimum difference threshold between two consecutive association update steps + :param normalize_stored_pattern: apply normalization on stored patterns + :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns + :param normalize_state_pattern: apply normalization on state patterns + :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns + :param normalize_pattern_projection: apply normalization on the pattern projection + :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection + :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space + :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space + :param stored_pattern_as_static: interpret specified stored patterns as being static + :param state_pattern_as_static: interpret specified state patterns as being static + :param pattern_projection_as_static: interpret specified pattern projections as being static + :param pattern_projection_as_connected: connect pattern projection with stored pattern + :param stored_pattern_size: depth of input (stored pattern) + :param pattern_projection_size: depth of input (pattern projection) + :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size + :param association_activation: additional activation to be applied on the 
result of the Hopfield association + :param dropout: dropout probability applied on the association matrix + :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection) + :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection + :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection + :param disable_out_projection: disable output projection + """ + super(Hopfield, self).__init__() + assert type(batch_first) == bool, f'"batch_first" needs to be a boolean, not {type(batch_first)}.' + assert (association_activation is None) or (type(association_activation) == str) + + # Initialise Hopfield association module. + self.association_core = HopfieldCore( + embed_dim=input_size, num_heads=num_heads, dropout=dropout, bias=input_bias, + add_bias_kv=concat_bias_pattern, add_zero_attn=add_zero_association, kdim=stored_pattern_size, + vdim=pattern_projection_size, head_dim=hidden_size, pattern_dim=pattern_size, out_dim=output_size, + disable_out_projection=disable_out_projection, key_as_static=stored_pattern_as_static, + query_as_static=state_pattern_as_static, value_as_static=pattern_projection_as_static, + value_as_connected=pattern_projection_as_connected, normalize_pattern=normalize_hopfield_space, + normalize_pattern_affine=normalize_hopfield_space_affine) + self.association_activation = None + if association_activation is not None: + self.association_activation = getattr(torch, association_activation, None) + + # Initialise stored pattern normalization. + self.norm_stored_pattern = None + if normalize_stored_pattern_affine: + assert normalize_stored_pattern, "affine normalization without normalization has no effect." 
+ if normalize_stored_pattern: + normalized_shape = input_size if stored_pattern_size is None else stored_pattern_size + assert normalized_shape is not None, "stored pattern size required for setting up normalisation" + self.norm_stored_pattern = nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine) + + # Initialise state pattern normalization. + self.norm_state_pattern = None + if normalize_state_pattern_affine: + assert normalize_state_pattern, "affine normalization without normalization has no effect." + if normalize_state_pattern: + assert input_size is not None, "input size required for setting up normalisation" + self.norm_state_pattern = nn.LayerNorm( + normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine) + + # Initialise pattern projection normalization. + self.norm_pattern_projection = None + if normalize_pattern_projection_affine: + assert normalize_pattern_projection, "affine normalization without normalization has no effect." + if normalize_pattern_projection: + normalized_shape = input_size if pattern_projection_size is None else pattern_projection_size + assert normalized_shape is not None, "pattern projection size required for setting up normalisation" + self.norm_pattern_projection = nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=normalize_pattern_projection_affine) + + # Initialise remaining auxiliary properties. + if self.association_core.static_execution: + self.__scaling = 1.0 if scaling is None else scaling + else: + assert self.association_core.head_dim > 0, f'invalid hidden dimension encountered.' + self.__scaling = (1.0 / sqrt(self.association_core.head_dim)) if scaling is None else scaling + self.__batch_first = batch_first + self.__update_steps_max = update_steps_max + self.__update_steps_eps = update_steps_eps + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Reset Hopfield association. 
+ + :return: None + """ + for module in (self.association_core, self.norm_stored_pattern, + self.norm_state_pattern, self.norm_pattern_projection): + if hasattr(module, r'reset_parameters'): + module.reset_parameters() + + def _maybe_transpose(self, *args: Tuple[Tensor, ...]) -> Union[Tensor, Tuple[Tensor, ...]]: + """ + Eventually transpose specified data. + + :param args: tensors to eventually transpose (dependent on the state of "batch_first") + :return: eventually transposed tensors + """ + transposed_result = tuple(_.transpose(0, 1) for _ in args) if self.__batch_first else args + return transposed_result[0] if len(transposed_result) == 1 else transposed_result + + def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + return_raw_associations: bool = False, return_projected_patterns: bool = False, + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]: + """ + Apply Hopfield association module on specified data. + + :param data: data to be processed by Hopfield core module + :param return_raw_associations: return raw association (softmax) values, unmodified + :param return_projected_patterns: return pattern projection values, unmodified + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-processed input data + """ + assert (type(data) == Tensor) or ((type(data) == tuple) and (len(data) == 3)), \ + r'either one tensor to be used as "stored pattern", "state pattern" and' \ + r' "pattern_projection" must be provided, or three separate ones.' + if type(data) == Tensor: + stored_pattern, state_pattern, pattern_projection = data, data, data + else: + stored_pattern, state_pattern, pattern_projection = data + + # Optionally transpose data. 
+ stored_pattern, state_pattern, pattern_projection = self._maybe_transpose( + stored_pattern, state_pattern, pattern_projection) + + # Optionally apply stored pattern normalization. + if self.norm_stored_pattern is not None: + stored_pattern = self.norm_stored_pattern(input=stored_pattern.reshape( + shape=(-1, stored_pattern.shape[2]))).reshape(shape=stored_pattern.shape) + + # Optionally apply state pattern normalization. + if self.norm_state_pattern is not None: + state_pattern = self.norm_state_pattern(input=state_pattern.reshape( + shape=(-1, state_pattern.shape[2]))).reshape(shape=state_pattern.shape) + + # Optionally apply pattern projection normalization. + if self.norm_pattern_projection is not None: + pattern_projection = self.norm_pattern_projection(input=pattern_projection.reshape( + shape=(-1, pattern_projection.shape[2]))).reshape(shape=pattern_projection.shape) + + # Apply Hopfield association and optional activation function. + return self.association_core( + query=state_pattern, key=stored_pattern, value=pattern_projection, + key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask, + scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps, + return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns) + + def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Apply Hopfield association on specified data. 
+ + :param input: data to be processed by Hopfield association module + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-processed input data + """ + association_output = self._maybe_transpose(self._associate( + data=input, return_raw_associations=False, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[0]) + if self.association_activation is not None: + association_output = self.association_activation(association_output) + return association_output + + def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield association matrix gathered by passing through the specified data. + + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: association matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self._associate( + data=input, return_raw_associations=True, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[2] + + def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield projected pattern matrix gathered by passing through the specified data. 
+ + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: pattern projection matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self._associate( + data=input, return_projected_patterns=True, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[3] + + @property + def batch_first(self) -> bool: + return self.__batch_first + + @property + def scaling(self) -> Union[float, Tensor]: + return self.__scaling.clone() if type(self.__scaling) == Tensor else self.__scaling + + @property + def stored_pattern_dim(self) -> Optional[int]: + return self.association_core.kdim + + @property + def state_pattern_dim(self) -> Optional[int]: + return self.association_core.embed_dim + + @property + def pattern_projection_dim(self) -> Optional[int]: + return self.association_core.vdim + + @property + def input_size(self) -> Optional[int]: + return self.state_pattern_dim + + @property + def hidden_size(self) -> Optional[int]: + return self.association_core.head_dim + + @property + def output_size(self) -> Optional[int]: + return self.association_core.out_dim + + @property + def pattern_size(self) -> Optional[int]: + return self.association_core.pattern_dim + + @property + def update_steps_max(self) -> Optional[Union[int, Tensor]]: + return self.__update_steps_max.clone() if type(self.__update_steps_max) == Tensor else self.__update_steps_max + + @property + def update_steps_eps(self) -> Optional[Union[float, Tensor]]: + return self.__update_steps_eps.clone() if type(self.__update_steps_eps) == Tensor else self.__update_steps_eps + + @property + def stored_pattern_as_static(self) -> bool: + return self.association_core.key_as_static + + @property + def state_pattern_as_static(self) -> bool: + return self.association_core.query_as_static + + 
@property + def pattern_projection_as_static(self) -> bool: + return self.association_core.value_as_static + + @property + def normalize_stored_pattern(self) -> bool: + return self.norm_stored_pattern is not None + + @property + def normalize_stored_pattern_affine(self) -> bool: + return self.normalize_stored_pattern and self.norm_stored_pattern.elementwise_affine + + @property + def normalize_state_pattern(self) -> bool: + return self.norm_state_pattern is not None + + @property + def normalize_state_pattern_affine(self) -> bool: + return self.normalize_state_pattern and self.norm_state_pattern.elementwise_affine + + @property + def normalize_pattern_projection(self) -> bool: + return self.norm_pattern_projection is not None + + @property + def normalize_pattern_projection_affine(self) -> bool: + return self.normalize_pattern_projection and self.norm_pattern_projection.elementwise_affine + + @property + def normalize_hopfield_space(self) -> bool: + return self.hopfield.normalize_hopfield_space + + @property + def normalize_hopfield_space_affine(self) -> bool: + return self.hopfield.normalize_hopfield_space_affine + + +class HopfieldPooling(Module): + """ + Wrapper class encapsulating a trainable but fixed state pattern and "Hopfield" in + one combined module to be used as a Hopfield-based pooling layer. 
+ """ + + def __init__(self, + input_size: int, + hidden_size: Optional[int] = None, + output_size: Optional[int] = None, + pattern_size: Optional[int] = None, + num_heads: int = 1, + scaling: Optional[Union[float, Tensor]] = None, + update_steps_max: Optional[Union[int, Tensor]] = 0, + update_steps_eps: Union[float, Tensor] = 1e-4, + + normalize_stored_pattern: bool = True, + normalize_stored_pattern_affine: bool = True, + normalize_state_pattern: bool = True, + normalize_state_pattern_affine: bool = True, + normalize_pattern_projection: bool = True, + normalize_pattern_projection_affine: bool = True, + normalize_hopfield_space: bool = False, + normalize_hopfield_space_affine: bool = False, + stored_pattern_as_static: bool = False, + state_pattern_as_static: bool = False, + pattern_projection_as_static: bool = False, + pattern_projection_as_connected: bool = False, + stored_pattern_size: Optional[int] = None, + pattern_projection_size: Optional[int] = None, + + batch_first: bool = True, + association_activation: Optional[str] = None, + dropout: float = 0.0, + input_bias: bool = True, + concat_bias_pattern: bool = False, + add_zero_association: bool = False, + disable_out_projection: bool = False, + quantity: int = 1, + trainable: bool = True + ): + """ + Initialise a new instance of a Hopfield-based pooling layer. 
+ + :param input_size: depth of the input (state pattern) + :param hidden_size: depth of the association space + :param output_size: depth of the output projection + :param pattern_size: depth of patterns to be selected + :param num_heads: amount of parallel association heads + :param scaling: scaling of association heads, often represented as beta (one entry per head) + :param update_steps_max: maximum count of association update steps (None equals to infinity) + :param update_steps_eps: minimum difference threshold between two consecutive association update steps + :param normalize_stored_pattern: apply normalization on stored patterns + :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns + :param normalize_state_pattern: apply normalization on state patterns + :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns + :param normalize_pattern_projection: apply normalization on the pattern projection + :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection + :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space + :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space + :param stored_pattern_as_static: interpret specified stored patterns as being static + :param state_pattern_as_static: interpret specified state patterns as being static + :param pattern_projection_as_static: interpret specified pattern projections as being static + :param pattern_projection_as_connected: connect pattern projection with stored pattern + :param stored_pattern_size: depth of input (stored pattern) + :param pattern_projection_size: depth of input (pattern projection) + :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size + :param association_activation: additional activation to be applied on the 
result of the Hopfield association + :param dropout: dropout probability applied on the association matrix + :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection) + :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection + :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection + :param disable_out_projection: disable output projection + :param quantity: amount of state patterns + :param trainable: state pattern used for pooling is trainable + """ + super(HopfieldPooling, self).__init__() + self.hopfield = Hopfield( + input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size, + num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps, + normalize_stored_pattern=normalize_stored_pattern, + normalize_stored_pattern_affine=normalize_stored_pattern_affine, + normalize_state_pattern=normalize_state_pattern, + normalize_state_pattern_affine=normalize_state_pattern_affine, + normalize_pattern_projection=normalize_pattern_projection, + normalize_pattern_projection_affine=normalize_pattern_projection_affine, + normalize_hopfield_space=normalize_hopfield_space, + normalize_hopfield_space_affine=normalize_hopfield_space_affine, + stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static, + pattern_projection_as_static=pattern_projection_as_static, + pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size, + pattern_projection_size=pattern_projection_size, batch_first=batch_first, + association_activation=association_activation, dropout=dropout, input_bias=input_bias, + concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association, + disable_out_projection=disable_out_projection) + self._quantity = quantity + pooling_weight_size = 
self.hopfield.hidden_size if state_pattern_as_static else self.hopfield.input_size + self.pooling_weights = nn.Parameter(torch.empty(size=(*( + (1, quantity) if batch_first else (quantity, 1) + ), input_size if pooling_weight_size is None else pooling_weight_size)), requires_grad=trainable) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Reset pooling weights and underlying Hopfield association. + + :return: None + """ + if hasattr(self.hopfield, r'reset_parameters'): + self.hopfield.reset_parameters() + + # Explicitly initialise pooling weights. + nn.init.normal_(self.pooling_weights, mean=0.0, std=0.02) + + def _prepare_input(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]: + """ + Prepare input for Hopfield association. + + :param input: data to be prepared + :return: stored pattern, expanded state pattern as well as pattern projection + """ + assert (type(input) == Tensor) or ((type(input) == tuple) and (len(input) == 2)), \ + r'either one tensor to be used as "stored pattern" and' \ + r' "pattern_projection" must be provided, or two separate ones.' + if type(input) == Tensor: + stored_pattern, pattern_projection = input, input + else: + stored_pattern, pattern_projection = input + + batch_size = stored_pattern.shape[0 if self.batch_first else 1] + return stored_pattern, self.pooling_weights.expand(size=(*( + (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size) + ), self.pooling_weights.shape[2])), pattern_projection + + def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor]], stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Compute Hopfield-based pooling on specified data. 
+ + :param input: data to be pooled + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-pooled input data + """ + return self.hopfield( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask).flatten(start_dim=1) + + def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield association matrix used for pooling gathered by passing through the specified data. + + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: association matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self.hopfield.get_association_matrix( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask) + + def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield projected pattern matrix gathered by passing through the specified data. 
+ + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: pattern projection matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self.hopfield.get_projected_pattern_matrix( + input=self._prepare_input(input=input), + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask) + + @property + def batch_first(self) -> bool: + return self.hopfield.batch_first + + @property + def scaling(self) -> Union[float, Tensor]: + return self.hopfield.scaling + + @property + def stored_pattern_dim(self) -> Optional[int]: + return self.hopfield.stored_pattern_dim + + @property + def state_pattern_dim(self) -> Optional[int]: + return self.hopfield.state_pattern_dim + + @property + def pattern_projection_dim(self) -> Optional[int]: + return self.hopfield.pattern_projection_dim + + @property + def input_size(self) -> Optional[int]: + return self.hopfield.input_size + + @property + def hidden_size(self) -> int: + return self.hopfield.hidden_size + + @property + def output_size(self) -> Optional[int]: + return self.hopfield.output_size + + @property + def pattern_size(self) -> Optional[int]: + return self.hopfield.pattern_size + + @property + def quantity(self) -> int: + return self._quantity + + @property + def update_steps_max(self) -> Optional[Union[int, Tensor]]: + return self.hopfield.update_steps_max + + @property + def update_steps_eps(self) -> Optional[Union[float, Tensor]]: + return self.hopfield.update_steps_eps + + @property + def stored_pattern_as_static(self) -> bool: + return self.hopfield.stored_pattern_as_static + + @property + def state_pattern_as_static(self) -> bool: + return self.hopfield.state_pattern_as_static + + @property + def pattern_projection_as_static(self) -> bool: + return self.hopfield.pattern_projection_as_static + + 
class HopfieldLayer(Module):
    """
    Wrapper class bundling a set of lookup weights (used as stored pattern and, unless separate target weights are
    requested, also as pattern projection) with a "Hopfield" module, to be used as a Hopfield-based lookup layer.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,
                 lookup_weights_as_separated: bool = False,
                 lookup_targets_as_trainable: bool = True,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,
                 quantity: int = 1,
                 trainable: bool = True
                 ):
        """
        Initialise a new instance of a Hopfield-based lookup layer.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param lookup_weights_as_separated: separate lookup weights from lookup target weights
        :param lookup_targets_as_trainable: employ trainable lookup target weights (used as pattern projection input)
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: additional activation to be applied on the result of the Hopfield association
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        :param quantity: amount of stored patterns
        :param trainable: stored pattern used for lookup is trainable
        """
        super().__init__()

        # All association-related settings are handed over unchanged to the wrapped Hopfield module.
        hopfield_settings = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            pattern_size=pattern_size,
            num_heads=num_heads,
            scaling=scaling,
            update_steps_max=update_steps_max,
            update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            normalize_pattern_projection=normalize_pattern_projection,
            normalize_pattern_projection_affine=normalize_pattern_projection_affine,
            normalize_hopfield_space=normalize_hopfield_space,
            normalize_hopfield_space_affine=normalize_hopfield_space_affine,
            stored_pattern_as_static=stored_pattern_as_static,
            state_pattern_as_static=state_pattern_as_static,
            pattern_projection_as_static=pattern_projection_as_static,
            pattern_projection_as_connected=pattern_projection_as_connected,
            stored_pattern_size=stored_pattern_size,
            pattern_projection_size=pattern_projection_size,
            batch_first=batch_first,
            association_activation=association_activation,
            dropout=dropout,
            input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern,
            add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection,
        )
        self.hopfield = Hopfield(**hopfield_settings)
        self._quantity = quantity

        # Depth of the lookup weights: the hidden size for static stored patterns, the stored pattern depth
        # otherwise; if the Hopfield module reports no size, the input size is used instead.
        if stored_pattern_as_static:
            lookup_depth = self.hopfield.hidden_size
        else:
            lookup_depth = self.hopfield.stored_pattern_dim
        if lookup_depth is None:
            lookup_depth = input_size

        # The singleton axis is the batch axis, it is expanded to the actual batch size in "_prepare_input".
        leading_dims = (1, quantity) if batch_first else (quantity, 1)
        self.lookup_weights = nn.Parameter(
            torch.empty(size=(*leading_dims, lookup_depth)), requires_grad=trainable)

        if lookup_weights_as_separated:
            target_depth = pattern_projection_size
            if target_depth is None:
                target_depth = self.lookup_weights.shape[2]
            self.target_weights = nn.Parameter(
                torch.empty(size=(*leading_dims, target_depth)), requires_grad=lookup_targets_as_trainable)
        else:
            # Without separate targets, the lookup weights double as pattern projection input.
            self.register_parameter(name=r'target_weights', param=None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Re-initialise the wrapped Hopfield association (if it supports it) as well as the lookup weights and, if
        present, the lookup target weights.

        :return: None
        """
        if hasattr(self.hopfield, r'reset_parameters'):
            self.hopfield.reset_parameters()

        # Lookup weights first, target weights second (both drawn from N(0, 0.02^2)).
        for weights in (self.lookup_weights, self.target_weights):
            if weights is not None:
                nn.init.normal_(weights, mean=0.0, std=0.02)

    def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Assemble the input triple for the Hopfield association.

        :param input: data to be used as state pattern
        :return: expanded stored pattern, (unmodified) state pattern as well as expanded pattern projection
        """
        batch_axis = 0 if self.batch_first else 1
        batch_size = input.shape[batch_axis]
        if self.batch_first:
            leading_dims = (batch_size, self.quantity)
        else:
            leading_dims = (self.quantity, batch_size)

        # Broadcast the weights along the batch axis (a view, no data is copied).
        stored_pattern = self.lookup_weights.expand(size=(*leading_dims, self.lookup_weights.shape[2]))
        if self.target_weights is None:
            return stored_pattern, input, stored_pattern

        pattern_projection = self.target_weights.expand(size=(*leading_dims, self.target_weights.shape[2]))
        return stored_pattern, input, pattern_projection

    def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Compute Hopfield-based lookup on specified data.

        :param input: data to be used in lookup
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: result of Hopfield-based lookup on input data
        """
        prepared = self._prepare_input(input=input)
        return self.hopfield(
            input=prepared,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)

    def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield association matrix used for lookup gathered by passing through the specified data.
        The computation is performed without tracking gradients.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            prepared = self._prepare_input(input=input)
            return self.hopfield.get_association_matrix(
                input=prepared,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield projected pattern matrix gathered by passing through the specified data.
        The computation is performed without tracking gradients.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            prepared = self._prepare_input(input=input)
            return self.hopfield.get_projected_pattern_matrix(
                input=prepared,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    # All properties below, except "quantity", are read-only views on the wrapped Hopfield module.

    @property
    def batch_first(self) -> bool:
        return self.hopfield.batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        return self.hopfield.scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        return self.hopfield.stored_pattern_dim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        return self.hopfield.state_pattern_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        return self.hopfield.pattern_projection_dim

    @property
    def input_size(self) -> Optional[int]:
        return self.hopfield.input_size

    @property
    def hidden_size(self) -> int:
        return self.hopfield.hidden_size

    @property
    def output_size(self) -> Optional[int]:
        return self.hopfield.output_size

    @property
    def pattern_size(self) -> Optional[int]:
        return self.hopfield.pattern_size

    @property
    def quantity(self) -> int:
        # Amount of lookup patterns, as specified on construction.
        return self._quantity

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        return self.hopfield.update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        return self.hopfield.update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        return self.hopfield.stored_pattern_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        return self.hopfield.state_pattern_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        return self.hopfield.pattern_projection_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        return self.hopfield.normalize_stored_pattern

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        return self.hopfield.normalize_stored_pattern_affine

    @property
    def normalize_state_pattern(self) -> bool:
        return self.hopfield.normalize_state_pattern

    @property
    def normalize_state_pattern_affine(self) -> bool:
        return self.hopfield.normalize_state_pattern_affine

    @property
    def normalize_pattern_projection(self) -> bool:
        return self.hopfield.normalize_pattern_projection

    @property
    def normalize_pattern_projection_affine(self) -> bool:
        return self.hopfield.normalize_pattern_projection_affine
class HopfieldCore(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See references: "Hopfield Networks is All You Need" and
    "Attention Is All You Need" (on which this implementation is partly based on).

    .. math::
        \text{HopfieldHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        head_dim: depth of the association space of each head. Default: ``embed_dim // num_heads``.
        pattern_dim: depth of the projected value of each head. Default: the (resolved) head dimension.
        out_dim: depth of the output projection. Default: ``embed_dim``.
        disable_out_projection: do not create (and apply) an output projection.
        key_as_static: treat the key as static, no key projection weights are created.
        query_as_static: treat the query as static, no query projection weights are created.
        value_as_static: treat the value as static, no value projection weights are created.
        value_as_connected: connect the value projection with the key projection.
        normalize_pattern: enable normalization of patterns.
        normalize_pattern_affine: additionally create affine pattern normalization parameters
                                  (requires ``normalize_pattern``).

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::

        >>> hopfield_attn = HopfieldCore(embed_dim, num_heads)
        >>> attn_output, attn_output_weights, attn_matrix = hopfield_attn(query, key, value)
    """
    # "Optional" from typing is used here instead of the alias exposed by the private "torch._jit_internal".
    __annotations__ = {
        'bias_k': Optional[Tensor],
        'bias_v': Optional[Tensor],
    }

    def __init__(self,
                 embed_dim=None,  # type: Optional[int]
                 num_heads=1,  # type: int
                 dropout=0.0,  # type: float
                 bias=True,  # type: bool
                 add_bias_kv=False,  # type: bool
                 add_zero_attn=False,  # type: bool
                 kdim=None,  # type: Optional[int]
                 vdim=None,  # type: Optional[int]

                 head_dim=None,  # type: Optional[int]
                 pattern_dim=None,  # type: Optional[int]
                 out_dim=None,  # type: Optional[int]
                 disable_out_projection=False,  # type: bool
                 key_as_static=False,  # type: bool
                 query_as_static=False,  # type: bool
                 value_as_static=False,  # type: bool
                 value_as_connected=False,  # type: bool
                 normalize_pattern=False,  # type: bool
                 normalize_pattern_affine=False  # type: bool
                 ):
        """
        Create all projection, normalization and bias parameters required by the chosen configuration.
        See the class documentation for a description of the arguments.
        """
        super(HopfieldCore, self).__init__()

        static_flags = (key_as_static, query_as_static, value_as_static)
        assert all(isinstance(flag, bool) for flag in static_flags)
        self.key_as_static, self.query_as_static, self.value_as_static = static_flags
        num_non_static = 3 - sum(static_flags)
        assert 0 <= num_non_static < 4

        self.value_as_connected = value_as_connected
        self.normalize_pattern, self.normalize_pattern_affine = normalize_pattern, normalize_pattern_affine
        self.disable_out_projection = disable_out_projection

        # In case of a static-only execution, check corresponding projections and normalizations.
        self.static_execution = self._check_execution_mode()
        if self.static_execution:
            embed_dim, kdim, vdim = None, None, None
        if embed_dim is None:
            assert self.static_execution, r'static-only execution requires all projections to be deactivated.'

        # Check and set all other properties, conditioned on the execution mode.
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = all((
            self.kdim == embed_dim, self.vdim == embed_dim, pattern_dim is None, not self.value_as_connected))
        assert (not self.value_as_connected) or (self.kdim == self.vdim), r'key and value need to be of same dimension.'

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = None
        self.pattern_dim = pattern_dim
        self.virtual_hopfield_dim = None
        self.virtual_pattern_dim = None
        if not self.static_execution:
            if head_dim is None:
                self.head_dim = embed_dim // num_heads
                assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads."
            else:
                assert head_dim > 0, "dimension of the association space has to be positive."
                self.head_dim = head_dim
            if self.pattern_dim is None:
                self.pattern_dim = self.head_dim
            self.virtual_hopfield_dim = self.num_heads * self.head_dim
            self.virtual_pattern_dim = self.num_heads * self.pattern_dim

        self.out_dim = embed_dim if out_dim is None else out_dim
        assert disable_out_projection or (self.out_dim > 0), "output projection dimension has to be positive."

        if normalize_pattern_affine:
            assert normalize_pattern, "affine pattern normalization without pattern normalization has no effect."
            # Size by the resolved head dimension: the "head_dim" argument may be None (default), in which
            # case it is derived from "embed_dim" above. Pattern normalization rules out static-only
            # execution, hence "self.head_dim" is always set at this point.
            self.p_norm_weight = Parameter(torch.Tensor(self.head_dim))
            self.p_norm_bias = Parameter(torch.Tensor(self.head_dim))
        else:
            self.register_parameter('p_norm_weight', None)
            self.register_parameter('p_norm_bias', None)

        if self._qkv_same_embed_dim is False:
            # Separate projection weights, one for each non-static input.
            if query_as_static:
                self.register_parameter('q_proj_weight', None)
            else:
                self.q_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, embed_dim))
            if key_as_static:
                self.register_parameter('k_proj_weight', None)
            else:
                self.k_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, self.kdim))
            if value_as_static:
                self.register_parameter('v_proj_weight', None)
            else:
                # A connected value is projected from the (projected) key, if the key is projected at all.
                v_in_dim = self.virtual_hopfield_dim if (value_as_connected and not key_as_static) else self.vdim
                self.v_proj_weight = Parameter(torch.Tensor(self.virtual_pattern_dim, v_in_dim))
            self.register_parameter('in_proj_weight', None)
        else:
            # One packed projection weight, holding a slice for each non-static input.
            if num_non_static > 0:
                self.in_proj_weight = Parameter(torch.empty(
                    (not query_as_static) * self.virtual_hopfield_dim +
                    (not key_as_static) * self.virtual_hopfield_dim +
                    (not value_as_static) * self.virtual_pattern_dim, embed_dim))
            else:
                self.register_parameter('in_proj_weight', None)
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias and (num_non_static > 0):
            # NOTE(review): in contrast to "in_proj_weight", the value slice is allocated even if the value is
            # static. Kept unchanged, as modifying the shape would break loading of existing checkpoints.
            self.in_proj_bias = Parameter(torch.empty(
                (not query_as_static) * self.virtual_hopfield_dim +
                (not key_as_static) * self.virtual_hopfield_dim + self.virtual_pattern_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        if disable_out_projection:
            self.register_parameter('out_proj', None)
        else:
            # "_LinearWithBias" is only available in some PyTorch versions (None otherwise).
            if bias and _LinearWithBias is not None:
                self.out_proj = _LinearWithBias(self.virtual_pattern_dim, self.out_dim)
            else:
                self.out_proj = Linear(self.virtual_pattern_dim, self.out_dim, bias=bias)

        self.bias_k, self.bias_v = None, None
        if add_bias_kv:
            if not key_as_static:
                self.bias_k = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            if not value_as_static:
                # NOTE(review): sized by the association depth, not by "virtual_pattern_dim" -- confirm that
                # this is intended for "pattern_dim != head_dim". Kept unchanged for checkpoint compatibility.
                self.bias_v = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            assert not (self.bias_k is None and self.bias_v is None), r'cannot set key/value bias if both are static.'

        self.add_zero_attn = add_zero_attn
        self.reset_parameters()

    def _check_execution_mode(self) -> bool:
        """
        Check whether the module is configured for a static-only execution, i.e. all of query, key and value
        are static and neither a connected value, a pattern normalization nor an output projection is requested.

        :return: True in case of a static-only execution, False otherwise
        """
        return all((
            self.key_as_static, self.query_as_static, self.value_as_static, not self.value_as_connected,
            not self.normalize_pattern, not self.normalize_pattern_affine, self.disable_out_projection
        ))

    def reset_parameters(self):
        """
        (Re-)initialise all existing parameters: normalization weights are set to one, normalization and
        projection biases to zero, all remaining parameters are drawn from a normal distribution (std=0.02).
        """
        if self.p_norm_weight is not None:
            nn.init.ones_(self.p_norm_weight)
            nn.init.zeros_(self.p_norm_bias)

        if self._qkv_same_embed_dim and (self.in_proj_weight is not None):
            nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02)
        else:
            for proj_weight in (self.q_proj_weight, self.k_proj_weight, self.v_proj_weight):
                if proj_weight is not None:
                    nn.init.normal_(proj_weight, mean=0.0, std=0.02)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
        if not self.disable_out_projection:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.02)
            if self.out_proj.bias is not None:
                nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.normal_(self.bias_k, mean=0.0, std=0.02)
        if self.bias_v is not None:
            nn.init.normal_(self.bias_v, mean=0.0, std=0.02)

    def __setstate__(self, state):
        super(HopfieldCore, self).__setstate__(state)

    def forward(self,
                query,  # type: Tensor
                key,  # type: Tensor
                value,  # type: Tensor
                key_padding_mask=None,  # type: Optional[Tensor]
                need_weights=True,  # type: bool
                attn_mask=None,  # type: Optional[Tensor]

                scaling=None,  # type: Optional[Tensor]
                update_steps_max=0,  # type: Optional[int]
                update_steps_eps=1e-4,  # type: float
                return_raw_associations=False,  # type: bool
                return_pattern_projections=False  # type: bool
                ):
        # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
            See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        scaling: scaling of association heads, often represented as beta (one entry per head).
        update_steps_max: maximum count of association update steps (None equals to infinity).
        update_steps_eps: minimum difference threshold between two consecutive association update steps.
        return_raw_associations: return raw association (softmax) values, unmodified.
        return_pattern_projections: return pattern projection values, unmodified.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

        - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        - attn_raw: :math:`(N, num_heads, L, S)`, where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        # Validate the feature depth of query and key; static inputs need to match the association depth.
        if self.query_as_static and self.key_as_static:
            assert query.shape[2] == key.shape[2], \
                f'query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal'
            head_dim, embed_dim_to_check = query.shape[2], query.shape[2]
        else:
            assert self.query_as_static or (query.shape[2] == self.embed_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}.'
            assert (not self.query_as_static) or (self.query_as_static and query.shape[2] == self.head_dim), \
                f'query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}'

            assert self.key_as_static or (key.shape[2] == self.kdim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}.'
            assert (not self.key_as_static) or (self.key_as_static and key.shape[2] == self.head_dim), \
                f'key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}'
            head_dim, embed_dim_to_check = self.head_dim, self.head_dim if self.query_as_static else self.embed_dim

        assert self.value_as_static or (value.shape[2] == self.vdim), \
            f'value shape[2] of {value.shape[2]} invalid, needs to be {self.vdim}.'
        assert any((
            not self.value_as_static, self.value_as_static and value.shape[2] == self.pattern_dim,
            self.disable_out_projection)
        ), f'value shape[2] of {value.shape[2]} invalid, needs to be {self.pattern_dim}'

        out_weights, out_bias = None, None
        if not self.disable_out_projection:
            out_weights, out_bias = self.out_proj.weight, self.out_proj.bias

        # In case of a packed input projection, the separate projection weights are registered as None, which
        # equals the defaults of "hopfield_core_forward". Hence, a single call covers both cases.
        return hopfield_core_forward(
            query=query, key=key, value=value, embed_dim_to_check=embed_dim_to_check, num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight, in_proj_bias=self.in_proj_bias, bias_k=self.bias_k,
            bias_v=self.bias_v, add_zero_attn=self.add_zero_attn, dropout_p=self.dropout,
            out_proj_weight=out_weights, out_proj_bias=out_bias, training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask,
            use_separate_proj_weight=not self._qkv_same_embed_dim, q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight,

            key_as_static=self.key_as_static, query_as_static=self.query_as_static,
            value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
            normalize_pattern=self.normalize_pattern,
            p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
            head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
            update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
            return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None, # type: Optional[Tensor] + + key_as_static=False, # type: bool + query_as_static=False, # type: bool + value_as_static=False, # type: bool + value_as_connected=False, # type: bool + normalize_pattern=False, # type: bool + p_norm_weight=None, # type: Optional[Tensor] + p_norm_bias=None, # type: Optional[Tensor] + head_dim=None, # type: Optional[int] + pattern_dim=None, # type: Optional[int] + scaling=None, # type: Optional[Union[float, Tensor]] + update_steps_max=0, # type: Optional[Union[int, Tensor]] + update_steps_eps=1e-4, # type: Union[float, Tensor] + return_raw_associations=False, # type: bool + return_projected_patterns=False # type: bool + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks. + embed_dim_to_check: total dimension of the model (in case of default head dimension). + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. 
+ out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + key_as_static: interpret specified key as being static. + query_as_static: interpret specified key as being static. + value_as_static: interpret specified key as being static. + value_as_connected: connect value projection with key projection. + normalize_pattern: enable normalization of patterns. + p_norm_weight, p_norm_bias: pattern normalization weight and bias. + head_dim: dimensionality of each head. + pattern_dim: dimensionality of each projected value input. + scaling: scaling of association heads, often represented as beta (one entry per head). + update_steps_max: maximum count of association update steps (None equals to infinity). + update_steps_eps: minimum difference threshold between two consecutive association update steps. + return_raw_associations: return raw association (softmax) values, unmodified. + return_projected_patterns: return pattern projection values, unmodified. 
+ + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, head_dim)`, where S is the source sequence length, N is the batch size. + - static_v: :math:`(N*num_heads, S, head_dim)`, where S is the source sequence length, N is the batch size. + + - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads. + + Outputs: + - attn_output: :math:`(L, N, E)`, where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)`, where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ - attn_raw: :math:``(N, num_heads, L, S)`, where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not torch.jit.is_scripting(): + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + out_proj_weight, out_proj_bias) + if any([type(t) is not Tensor for t in tens_ops]) and nn.functional.has_torch_function(tens_ops): + return nn.functional.handle_torch_function( + hopfield_core_forward, tens_ops, query, key, value, + embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + out_proj_bias, training=training, key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v, + key_as_static=key_as_static, query_as_static=query_as_static, + value_as_static=value_as_static, value_as_connected=value_as_connected, + normalize_pattern=normalize_pattern, p_norm_weight=p_norm_weight, p_norm_bias=p_norm_bias, + head_dim=head_dim, pattern_dim=pattern_dim, scaling=scaling, update_steps_max=update_steps_max, + update_steps_eps=update_steps_eps, return_raw_associations=return_raw_associations) + tgt_len, bsz, embed_dim = query.shape[0], value.shape[1], query.shape[2] + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + assert (scaling is None) or (type(scaling) in (float, torch.Tensor)) + if type(scaling) == torch.Tensor: + assert scaling.ndimension() == 1 and scaling.shape[0] == num_heads, "only one entry per head." 
+ + assert (update_steps_max is None) or (type(update_steps_max) in (int, torch.Tensor)) + if type(update_steps_max) == torch.Tensor: + assert update_steps_max.ndimension() == 1 and update_steps_max.shape[0] == num_heads, "only one entry per head." + elif type(update_steps_max) == int: + update_steps_max = torch.tensor([update_steps_max] * num_heads, dtype=torch.int32, device=query.device) + elif update_steps_max is None: + update_steps_max = -torch.ones(size=(num_heads,), dtype=torch.int32, device=query.device) + + assert type(update_steps_eps) in (float, torch.Tensor) + if type(update_steps_eps) == torch.Tensor: + assert update_steps_eps.ndimension() == 1 and update_steps_eps.shape[0] == num_heads, "only one entry per head." + assert (update_steps_eps <= 0.0).sum() == 0, "only positive thresholds allowed." + update_steps_eps = update_steps_eps.to(device=query.device) + elif type(update_steps_eps) == float: + assert update_steps_eps > 0, "only positive thresholds allowed." + update_steps_eps = torch.tensor([update_steps_eps] * num_heads, dtype=query.dtype, device=query.device) + + # Adapt dimensionality of each each. + if head_dim is None: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, r'embed_dim must be divisible by num_heads.' + hopfield_dim = num_heads * head_dim + + # Adapt dimensionality of each value projection. + if pattern_dim is None: + pattern_dim = head_dim + assert (not value_as_connected) or (pattern_dim == head_dim) + + q, k, v, xi, src_len = None, None, None, None, 0 + update_step, xi_old, xi_difference_norm = 0, None, float(r'+inf') + update_active_heads = torch.tensor([[[True]]] * num_heads * bsz, device=query.device) + assert update_active_heads.any(), "at least one head needs to be active." 
+ + #################################################################################################################### + # BEGIN HOPFIELD UPDATE ITERATION # + #################################################################################################################### + + while update_active_heads.any(): + + # The query is already projected into the "Hopfield" space at "update_step" equals 0. + # No more projection necessary if "update_step" greater than 0. + if update_step == 0: + if not use_separate_proj_weight: + + if torch.equal(query, key) and torch.equal(key, value) and not ( + key_as_static or query_as_static or value_as_static): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value) and not (key_as_static or value_as_static): + # encoder-decoder attention + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + _start = hopfield_dim + _end = None + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + _start += hopfield_dim + _end += hopfield_dim + + if key_as_static: + k = key.repeat(1, num_heads, 1) + else: + # This is inline in_proj 
function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + _start += hopfield_dim + _end += hopfield_dim + + if value_as_static: + v = value.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + v = nn.functional.linear(value, _w, _b) + else: + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == hopfield_dim and len2 == query.size(-1) + if in_proj_bias is not None: + q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias[_start:_end]) + _start += hopfield_dim + _end += hopfield_dim + else: + q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias) + + v = value + if key_as_static: + k = key.repeat(1, num_heads, 1) + else: + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == hopfield_dim and len2 == key.size(-1) + + _bias = None if in_proj_bias is None else in_proj_bias[_start:_end] + k = nn.functional.linear(key, k_proj_weight_non_opt, _bias) + if value_as_connected: + v = nn.functional.linear(v, k_proj_weight_non_opt, _bias) + _start += hopfield_dim + _end += num_heads * pattern_dim + + if value_as_static: + if not (value_as_connected or key_as_static): + v = v.repeat(1, num_heads, 1) + else: + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == (num_heads * pattern_dim) and len2 == v.size(-1) + if in_proj_bias is not None: + v = nn.functional.linear(v, v_proj_weight_non_opt, in_proj_bias[_start:]) + else: + v = nn.functional.linear(v, 
v_proj_weight_non_opt, in_proj_bias) + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or \ + attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # Optionally normalize patterns. + if normalize_pattern: + q = torch.nn.functional.layer_norm( + input=q.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,), + weight=p_norm_weight, bias=p_norm_bias).reshape(shape=q.shape) + k = torch.nn.functional.layer_norm( + input=k.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,), + weight=p_norm_weight, bias=p_norm_bias).reshape(shape=k.shape) + + else: + active_xi = xi.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:])) + active_k = k.masked_select(mask=update_active_heads).view(size=(-1, *k.shape[1:])) + q = torch.masked_scatter(input=q, mask=update_active_heads, source=torch.bmm(active_xi, active_k)) + + # Optionally scale association heads (each head separately). 
+ if type(scaling) == float: + q = q * scaling + elif type(scaling) == torch.Tensor: + q = q * scaling.view(1, 1, -1).repeat(repeats=(1, 1, q.shape[2] // scaling.shape[0])) + + if update_step == 0: + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None and key_as_static is None and value_as_static is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = nn.functional.pad(attn_mask, [0, 1]) + if key_padding_mask is not None: + key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1]) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + assert not key_as_static, "bias cannot be added to static key." + assert not value_as_static, "bias cannot be added to static value." 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, -1, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(v.shape[0], bsz * num_heads, -1).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == pattern_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = nn.functional.pad(attn_mask, [0, 1]) + if key_padding_mask is not None: + key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1]) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + # Compute new xi for Hopfield retrieve iterations. 
+ if xi is None: + xi = nn.functional.softmax(attn_output_weights, dim=-1) + else: + xi = torch.masked_scatter(input=xi, mask=update_active_heads, source=nn.functional.softmax( + attn_output_weights.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:])), dim=-1)) + + # Compute threshold-based stopping criterion for Hopfield retrieve iterations. + with torch.no_grad(): + xi_active = xi.view(size=(bsz, num_heads, tgt_len, src_len)) + update_active_heads = (update_step < update_steps_max) | (update_steps_max < 0) + if xi_old is not None: + update_active_heads &= ((xi_old - xi_active).norm(p=2, dim=(2, 3)).max(axis=0)[0]) > update_steps_eps + update_active_heads = update_active_heads.unsqueeze(dim=1).unsqueeze(dim=2).repeat(repeats=(bsz, 1, 1)) + xi_old = xi_active + update_step += 1 + + #################################################################################################################### + # END HOPFIELD UPDATE ITERATION # + #################################################################################################################### + + attn_output_weights = nn.functional.dropout(xi, p=dropout_p, training=training) + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.shape[:2]) == [bsz * num_heads, tgt_len] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1) + if out_proj_weight is not None: + assert attn_output.shape[2] == num_heads * pattern_dim + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + xi = xi.view(bsz, num_heads, tgt_len, src_len) if return_raw_associations else None + v = v.view(bsz, num_heads, src_len, -1) if return_projected_patterns else None + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads, xi, v + else: + return attn_output, None, xi, v diff --git 
class HopfieldEncoderLayer(Module):
    """
    Module with underlying Hopfield association to be used as an encoder in transformer-like architectures.

    The layer consists of two stages, each followed by dropout, a residual connection and a layer normalization:
    first the Hopfield association, afterwards a two-layer feed-forward block (linear, activation, dropout, linear).
    """

    def __init__(self,
                 hopfield_association: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Initialise a new instance of a Hopfield association-based encoder module.

        :param hopfield_association: instance of Hopfield association module (a deep copy of it is stored)
        :param dim_feedforward: depth of the linear projections applied internally
        :param dropout: dropout probability to be applied internally
        :param activation: name of a function of the "torch" namespace, applied on the result of the
            first internal linear projection
        """
        super(HopfieldEncoderLayer, self).__init__()
        # Work on a private copy, so that the module passed by the caller is neither shared nor re-initialised.
        self.hopfield_association = deepcopy(hopfield_association)

        # Feed-forward block: state pattern depth -> dim_feedforward -> state pattern depth.
        self.linear_residual = nn.Linear(self.hopfield_association.state_pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, self.hopfield_association.state_pattern_dim)

        self.norm_residual = nn.LayerNorm(self.hopfield_association.state_pattern_dim)
        self.norm_output = nn.LayerNorm(self.hopfield_association.state_pattern_dim)
        self.dropout_hopfield_association = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is looked up by name directly in the "torch" namespace (e.g. "relu" -> torch.relu).
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset parameters, including Hopfield association.

        :return: None
        """
        for module in (self.hopfield_association, self.linear_residual,
                       self.linear_output, self.norm_residual, self.norm_output):
            # Sub-modules without a "reset_parameters" method are silently skipped.
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield encoding on specified data.

        :param src: data to be processed by Hopfield encoder module
        :param src_mask: mask to be applied on association matrix
        :param src_key_padding_mask: mask to be applied on stored patterns
        :return: Hopfield-encoded input data
        """
        # Stage 1: Hopfield association with dropout, residual connection and normalization.
        data_associated = self.hopfield_association(
            input=src, stored_pattern_padding_mask=src_key_padding_mask, association_mask=src_mask)
        src = src + self.dropout_hopfield_association(input=data_associated)
        src = self.norm_residual(input=src)

        # Stage 2: feed-forward block with dropout, residual connection and normalization.
        result_residual_inner = self.activation_residual(input=self.linear_residual(input=src))
        data_associated = self.linear_output(input=self.dropout_residual(input=result_residual_inner))
        src = src + self.dropout_output(input=data_associated)

        return self.norm_output(input=src)

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association.get_association_matrix(input=input)

    @property
    def batch_first(self) -> bool:
        # Forwarded from the wrapped association; presumably the boolean flag passed to its constructor.
        return self.hopfield_association.batch_first

    @property
    def input_size(self) -> int:
        return self.hopfield_association.input_size

    @property
    def output_size(self) -> int:
        return self.linear_output.out_features
class HopfieldDecoderLayer(Module):
    """
    Module with underlying Hopfield associations to be used as a decoder in transformer-like architectures.

    The layer consists of three stages, each followed by dropout, a residual connection and a layer
    normalization: a Hopfield self-association on the target, a Hopfield cross-association between target and
    memory, and a two-layer feed-forward block (linear, activation, dropout, linear).
    """

    def __init__(self,
                 hopfield_association_self: Hopfield,
                 hopfield_association_cross: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Initialise a new instance of a Hopfield association-based decoder module.

        :param hopfield_association_self: instance of Hopfield self-association module (a deep copy is stored)
        :param hopfield_association_cross: instance of Hopfield cross-association module (a deep copy is stored)
        :param dim_feedforward: depth of the linear projections applied internally
        :param dropout: dropout probability to be applied internally
        :param activation: name of a function of the "torch" namespace, applied on the result of the
            first internal linear projection
        """
        super(HopfieldDecoderLayer, self).__init__()
        # Work on private copies, so that the modules passed by the caller are neither shared nor re-initialised.
        self.hopfield_association_self = deepcopy(hopfield_association_self)
        self.hopfield_association_cross = deepcopy(hopfield_association_cross)

        # All internal widths are derived from the state pattern depth of the self-association.
        self.linear_residual = nn.Linear(self.hopfield_association_self.state_pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, self.hopfield_association_self.state_pattern_dim)

        self.norm_residual_self = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.norm_residual_cross = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.norm_output = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.dropout_hopfield_association_self = nn.Dropout(dropout)
        self.dropout_hopfield_association_cross = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is looked up by name directly in the "torch" namespace (e.g. "relu" -> torch.relu).
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset parameters, including Hopfield association.

        :return: None
        """
        for module in (self.hopfield_association_self, self.hopfield_association_cross,
                       self.linear_residual, self.linear_output, self.norm_residual_self,
                       self.norm_residual_cross, self.norm_output):
            # Sub-modules without a "reset_parameters" method are silently skipped.
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield decoding on specified data.

        :param tgt: data to be processed by Hopfield decoder module (self-association)
        :param memory: data used as stored pattern and pattern projection of the cross-association
        :param tgt_mask: mask to be applied on self-association matrix
        :param memory_mask: mask to be applied on cross-association matrix
        :param tgt_key_padding_mask: mask to be applied on stored patterns of the self-association
        :param memory_key_padding_mask: mask to be applied on stored patterns of the cross-association
        :return: Hopfield-decoded input
        """
        # Stage 1: self-association of the target with dropout, residual connection and normalization.
        data_associated = self.hopfield_association_self(
            input=tgt, stored_pattern_padding_mask=tgt_key_padding_mask,
            association_mask=tgt_mask)
        tgt = tgt + self.dropout_hopfield_association_self(input=data_associated)
        tgt = self.norm_residual_self(input=tgt)

        # Stage 2: cross-association. The tuple is ordered (stored pattern, state pattern, pattern projection),
        # i.e. the target acts as state pattern and retrieves from the memory.
        data_associated = self.hopfield_association_cross(
            input=(memory, tgt, memory), stored_pattern_padding_mask=memory_key_padding_mask,
            association_mask=memory_mask)
        tgt = tgt + self.dropout_hopfield_association_cross(input=data_associated)
        tgt = self.norm_residual_cross(input=tgt)

        # Stage 3: feed-forward block with dropout, residual connection and normalization.
        result_residual_inner = self.activation_residual(input=self.linear_residual(input=tgt))
        data_associated = self.linear_output(input=self.dropout_residual(input=result_residual_inner))
        tgt = tgt + self.dropout_output(input=data_associated)
        return self.norm_output(input=tgt)

    def get_association_matrix_self(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield self-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_self.get_association_matrix(input=input)

    def get_association_matrix_cross(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield cross-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_cross.get_association_matrix(input=input)

    @property
    def batch_first(self) -> bool:
        # Forwarded from the self-association; presumably the boolean flag passed to its constructor.
        return self.hopfield_association_self.batch_first

    @property
    def input_size(self) -> int:
        return self.hopfield_association_self.input_size

    @property
    def output_size(self) -> int:
        # The output projection is stored as "linear_output"; an attribute "linear_output_self" is never
        # created, so reading it raised an AttributeError on every access.
        return self.linear_output.out_features
100644 index 0000000000000000000000000000000000000000..f810c2f8e6d8f277e4f0073c3d1c9312e986caa4 Binary files /dev/null and b/src/mhnfs/hopfield/examples/mnist_bags/resources/gated_attention_base.pdf differ diff --git a/src/mhnfs/hopfield/examples/mnist_bags/resources/hopfield_pooling.pdf b/src/mhnfs/hopfield/examples/mnist_bags/resources/hopfield_pooling.pdf new file mode 100644 index 0000000000000000000000000000000000000000..998d74eea9f281277cfb49cb05f25fd6dc8fbd2d Binary files /dev/null and b/src/mhnfs/hopfield/examples/mnist_bags/resources/hopfield_pooling.pdf differ diff --git a/src/mhnfs/hopfield/modules/__init__.py b/src/mhnfs/hopfield/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c11713739acb1aca3b32cca53acf0c3faf1fa9f6 --- /dev/null +++ b/src/mhnfs/hopfield/modules/__init__.py @@ -0,0 +1,898 @@ +import torch +import torch.nn as nn + +from math import sqrt +from torch import Tensor +from torch.nn import Module, Parameter +from typing import Optional, Tuple, Union + +from .activation import HopfieldCore + + +class Hopfield(Module): + """ + Module with underlying Hopfield association. 
+ """ + + def __init__(self, + input_size: Optional[int] = None, + hidden_size: Optional[int] = None, + output_size: Optional[int] = None, + pattern_size: Optional[int] = None, + num_heads: int = 1, + scaling: Optional[Union[float, Tensor]] = None, + update_steps_max: Optional[Union[int, Tensor]] = 0, + update_steps_eps: Union[float, Tensor] = 1e-4, + + normalize_stored_pattern: bool = True, + normalize_stored_pattern_affine: bool = True, + normalize_state_pattern: bool = True, + normalize_state_pattern_affine: bool = True, + normalize_pattern_projection: bool = True, + normalize_pattern_projection_affine: bool = True, + normalize_hopfield_space: bool = False, + normalize_hopfield_space_affine: bool = False, + stored_pattern_as_static: bool = False, + state_pattern_as_static: bool = False, + pattern_projection_as_static: bool = False, + pattern_projection_as_connected: bool = False, + stored_pattern_size: Optional[int] = None, + pattern_projection_size: Optional[int] = None, + + batch_first: bool = True, + association_activation: Optional[str] = None, + dropout: float = 0.0, + input_bias: bool = True, + concat_bias_pattern: bool = False, + add_zero_association: bool = False, + disable_out_projection: bool = False + ): + """ + Initialise new instance of a Hopfield module. 
+ + :param input_size: depth of the input (state pattern) + :param hidden_size: depth of the association space + :param output_size: depth of the output projection + :param pattern_size: depth of patterns to be selected + :param num_heads: amount of parallel association heads + :param scaling: scaling of association heads, often represented as beta (one entry per head) + :param update_steps_max: maximum count of association update steps (None equals to infinity) + :param update_steps_eps: minimum difference threshold between two consecutive association update steps + :param normalize_stored_pattern: apply normalization on stored patterns + :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns + :param normalize_state_pattern: apply normalization on state patterns + :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns + :param normalize_pattern_projection: apply normalization on the pattern projection + :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection + :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space + :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space + :param stored_pattern_as_static: interpret specified stored patterns as being static + :param state_pattern_as_static: interpret specified state patterns as being static + :param pattern_projection_as_static: interpret specified pattern projections as being static + :param pattern_projection_as_connected: connect pattern projection with stored pattern + :param stored_pattern_size: depth of input (stored pattern) + :param pattern_projection_size: depth of input (pattern projection) + :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size + :param association_activation: additional activation to be applied on the 
result of the Hopfield association + :param dropout: dropout probability applied on the association matrix + :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection) + :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection + :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection + :param disable_out_projection: disable output projection + """ + super(Hopfield, self).__init__() + assert type(batch_first) == bool, f'"batch_first" needs to be a boolean, not {type(batch_first)}.' + assert (association_activation is None) or (type(association_activation) == str) + + # Initialise Hopfield association module. + self.association_core = HopfieldCore( + embed_dim=input_size, num_heads=num_heads, dropout=dropout, bias=input_bias, + add_bias_kv=concat_bias_pattern, add_zero_attn=add_zero_association, kdim=stored_pattern_size, + vdim=pattern_projection_size, head_dim=hidden_size, pattern_dim=pattern_size, out_dim=output_size, + disable_out_projection=disable_out_projection, key_as_static=stored_pattern_as_static, + query_as_static=state_pattern_as_static, value_as_static=pattern_projection_as_static, + value_as_connected=pattern_projection_as_connected, normalize_pattern=normalize_hopfield_space, + normalize_pattern_affine=normalize_hopfield_space_affine) + self.association_activation = None + if association_activation is not None: + self.association_activation = getattr(torch, association_activation, None) + + # Initialise stored pattern normalization. + self.norm_stored_pattern = None + if normalize_stored_pattern_affine: + assert normalize_stored_pattern, "affine normalization without normalization has no effect." 
+ if normalize_stored_pattern: + normalized_shape = input_size if stored_pattern_size is None else stored_pattern_size + assert normalized_shape is not None, "stored pattern size required for setting up normalisation" + self.norm_stored_pattern = nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine) + + # Initialise state pattern normalization. + self.norm_state_pattern = None + if normalize_state_pattern_affine: + assert normalize_state_pattern, "affine normalization without normalization has no effect." + if normalize_state_pattern: + assert input_size is not None, "input size required for setting up normalisation" + self.norm_state_pattern = nn.LayerNorm( + normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine) + + # Initialise pattern projection normalization. + self.norm_pattern_projection = None + if normalize_pattern_projection_affine: + assert normalize_pattern_projection, "affine normalization without normalization has no effect." + if normalize_pattern_projection: + normalized_shape = input_size if pattern_projection_size is None else pattern_projection_size + assert normalized_shape is not None, "pattern projection size required for setting up normalisation" + self.norm_pattern_projection = nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=normalize_pattern_projection_affine) + + # Initialise remaining auxiliary properties. + if self.association_core.static_execution: + self.__scaling = 1.0 if scaling is None else scaling + else: + assert self.association_core.head_dim > 0, f'invalid hidden dimension encountered.' + self.__scaling = (1.0 / sqrt(self.association_core.head_dim)) if scaling is None else scaling + self.__batch_first = batch_first + self.__update_steps_max = update_steps_max + self.__update_steps_eps = update_steps_eps + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Reset Hopfield association. 
+ + :return: None + """ + for module in (self.association_core, self.norm_stored_pattern, + self.norm_state_pattern, self.norm_pattern_projection): + if hasattr(module, r'reset_parameters'): + module.reset_parameters() + + def _maybe_transpose(self, *args: Tuple[Tensor, ...]) -> Union[Tensor, Tuple[Tensor, ...]]: + """ + Eventually transpose specified data. + + :param args: tensors to eventually transpose (dependent on the state of "batch_first") + :return: eventually transposed tensors + """ + transposed_result = tuple(_.transpose(0, 1) for _ in args) if self.__batch_first else args + return transposed_result[0] if len(transposed_result) == 1 else transposed_result + + def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + return_raw_associations: bool = False, return_projected_patterns: bool = False, + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]: + """ + Apply Hopfield association module on specified data. + + :param data: data to be processed by Hopfield core module + :param return_raw_associations: return raw association (softmax) values, unmodified + :param return_projected_patterns: return pattern projection values, unmodified + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-processed input data + """ + assert (type(data) == Tensor) or ((type(data) == tuple) and (len(data) == 3)), \ + r'either one tensor to be used as "stored pattern", "state pattern" and' \ + r' "pattern_projection" must be provided, or three separate ones.' + if type(data) == Tensor: + stored_pattern, state_pattern, pattern_projection = data, data, data + else: + stored_pattern, state_pattern, pattern_projection = data + + # Optionally transpose data. 
+ stored_pattern, state_pattern, pattern_projection = self._maybe_transpose( + stored_pattern, state_pattern, pattern_projection) + + # Optionally apply stored pattern normalization. + if self.norm_stored_pattern is not None: + stored_pattern = self.norm_stored_pattern(input=stored_pattern.reshape( + shape=(-1, stored_pattern.shape[2]))).reshape(shape=stored_pattern.shape) + + # Optionally apply state pattern normalization. + if self.norm_state_pattern is not None: + state_pattern = self.norm_state_pattern(input=state_pattern.reshape( + shape=(-1, state_pattern.shape[2]))).reshape(shape=state_pattern.shape) + + # Optionally apply pattern projection normalization. + if self.norm_pattern_projection is not None: + pattern_projection = self.norm_pattern_projection(input=pattern_projection.reshape( + shape=(-1, pattern_projection.shape[2]))).reshape(shape=pattern_projection.shape) + + # Apply Hopfield association and optional activation function. + return self.association_core( + query=state_pattern, key=stored_pattern, value=pattern_projection, + key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask, + scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps, + return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns) + + def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Apply Hopfield association on specified data. 
+ + :param input: data to be processed by Hopfield association module + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: Hopfield-processed input data + """ + association_output = self._maybe_transpose(self._associate( + data=input, return_raw_associations=False, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[0]) + if self.association_activation is not None: + association_output = self.association_activation(association_output) + return association_output + + def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield association matrix gathered by passing through the specified data. + + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: association matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self._associate( + data=input, return_raw_associations=True, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[2] + + def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], + stored_pattern_padding_mask: Optional[Tensor] = None, + association_mask: Optional[Tensor] = None) -> Tensor: + """ + Fetch Hopfield projected pattern matrix gathered by passing through the specified data. 
+ + :param input: data to be passed through the Hopfield association + :param stored_pattern_padding_mask: mask to be applied on stored patterns + :param association_mask: mask to be applied on inner association matrix + :return: pattern projection matrix as computed by the Hopfield core module + """ + with torch.no_grad(): + return self._associate( + data=input, return_projected_patterns=True, + stored_pattern_padding_mask=stored_pattern_padding_mask, + association_mask=association_mask)[3] + + @property + def batch_first(self) -> bool: + return self.__batch_first + + @property + def scaling(self) -> Union[float, Tensor]: + return self.__scaling.clone() if type(self.__scaling) == Tensor else self.__scaling + + @property + def stored_pattern_dim(self) -> Optional[int]: + return self.association_core.kdim + + @property + def state_pattern_dim(self) -> Optional[int]: + return self.association_core.embed_dim + + @property + def pattern_projection_dim(self) -> Optional[int]: + return self.association_core.vdim + + @property + def input_size(self) -> Optional[int]: + return self.state_pattern_dim + + @property + def hidden_size(self) -> Optional[int]: + return self.association_core.head_dim + + @property + def output_size(self) -> Optional[int]: + return self.association_core.out_dim + + @property + def pattern_size(self) -> Optional[int]: + return self.association_core.pattern_dim + + @property + def update_steps_max(self) -> Optional[Union[int, Tensor]]: + return self.__update_steps_max.clone() if type(self.__update_steps_max) == Tensor else self.__update_steps_max + + @property + def update_steps_eps(self) -> Optional[Union[float, Tensor]]: + return self.__update_steps_eps.clone() if type(self.__update_steps_eps) == Tensor else self.__update_steps_eps + + @property + def stored_pattern_as_static(self) -> bool: + return self.association_core.key_as_static + + @property + def state_pattern_as_static(self) -> bool: + return self.association_core.query_as_static + + 
# ----------------------------------------------------------------------------------------------
# Remaining read-only accessors of the "Hopfield" wrapper. They are members of that class body
# (its header precedes this span) and are written flush-left only because the surrounding text
# carries no indentation.
# ----------------------------------------------------------------------------------------------
@property
def pattern_projection_as_static(self) -> bool:
    """Whether the association core treats the pattern projection (value) as static."""
    return self.association_core.value_as_static


@property
def normalize_stored_pattern(self) -> bool:
    """Whether a normalization layer for stored patterns has been set up."""
    return self.norm_stored_pattern is not None


@property
def normalize_stored_pattern_affine(self) -> bool:
    """Whether the stored pattern normalization exists and is elementwise affine."""
    return self.normalize_stored_pattern and self.norm_stored_pattern.elementwise_affine


@property
def normalize_state_pattern(self) -> bool:
    """Whether a normalization layer for state patterns has been set up."""
    return self.norm_state_pattern is not None


@property
def normalize_state_pattern_affine(self) -> bool:
    """Whether the state pattern normalization exists and is elementwise affine."""
    return self.normalize_state_pattern and self.norm_state_pattern.elementwise_affine


@property
def normalize_pattern_projection(self) -> bool:
    """Whether a normalization layer for the pattern projection has been set up."""
    return self.norm_pattern_projection is not None


@property
def normalize_pattern_projection_affine(self) -> bool:
    """Whether the pattern projection normalization exists and is elementwise affine."""
    return self.normalize_pattern_projection and self.norm_pattern_projection.elementwise_affine


@property
def normalize_hopfield_space(self) -> bool:
    """Whether patterns are normalized in the Hopfield space."""
    # The visible initialiser of this class never sets a "hopfield" attribute (only the pooling
    # wrapper below does), so the former "self.hopfield.normalize_hopfield_space" lookup could
    # only raise AttributeError. The flag is handed to the association core as "normalize_pattern".
    # NOTE(review): assumes HopfieldCore stores that constructor argument under the same name,
    # as it does for the other attributes read by these accessors - confirm against HopfieldCore.
    return self.association_core.normalize_pattern


@property
def normalize_hopfield_space_affine(self) -> bool:
    """Whether the normalization in the Hopfield space is affine."""
    # Same fix as for "normalize_hopfield_space"; the flag is handed to the association core as
    # "normalize_pattern_affine".
    # NOTE(review): attribute name assumed to mirror the constructor argument - confirm.
    return self.association_core.normalize_pattern_affine


class HopfieldPooling(Module):
    """
    Wrapper class encapsulating a trainable but fixed state pattern and "Hopfield" in
    one combined module to be used as a Hopfield-based pooling layer.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,
                 quantity: int = 1,
                 trainable: bool = True
                 ):
        """
        Initialise a new instance of a Hopfield-based pooling layer.

        All arguments except "quantity" and "trainable" are forwarded unchanged to the wrapped
        "Hopfield" module.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: additional activation to be applied on the result of the Hopfield association
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        :param quantity: amount of state patterns
        :param trainable: state pattern used for pooling is trainable
        """
        super(HopfieldPooling, self).__init__()
        self.hopfield = Hopfield(
            input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size,
            num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            normalize_pattern_projection=normalize_pattern_projection,
            normalize_pattern_projection_affine=normalize_pattern_projection_affine,
            normalize_hopfield_space=normalize_hopfield_space,
            normalize_hopfield_space_affine=normalize_hopfield_space_affine,
            stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static,
            pattern_projection_as_static=pattern_projection_as_static,
            pattern_projection_as_connected=pattern_projection_as_connected, stored_pattern_size=stored_pattern_size,
            pattern_projection_size=pattern_projection_size, batch_first=batch_first,
            association_activation=association_activation, dropout=dropout, input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection)
        self._quantity = quantity

        # Depth of the pooling (state) pattern: the hidden size if the state pattern is static,
        # the input size otherwise; fall back to "input_size" if the wrapped module reports None.
        state_depth = self.hopfield.hidden_size if state_pattern_as_static else self.hopfield.input_size
        if state_depth is None:
            state_depth = input_size

        # The batch dimension is kept at size one here and expanded per call in "_prepare_input".
        leading_dims = (1, quantity) if batch_first else (quantity, 1)
        self.pooling_weights = nn.Parameter(
            torch.empty(size=(*leading_dims, state_depth)), requires_grad=trainable)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset pooling weights and underlying Hopfield association.

        :return: None
        """
        if hasattr(self.hopfield, r'reset_parameters'):
            self.hopfield.reset_parameters()

        # Explicitly initialise pooling weights.
        nn.init.normal_(self.pooling_weights, mean=0.0, std=0.02)

    def _prepare_input(self, input: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Prepare input for Hopfield association.

        A single tensor serves as both stored pattern and pattern projection. Tensor and tuple
        subclasses (e.g. "nn.Parameter") are accepted as well.

        :param input: data to be prepared
        :return: stored pattern, expanded state pattern as well as pattern projection
        """
        is_single_tensor = isinstance(input, Tensor)
        assert is_single_tensor or (isinstance(input, tuple) and (len(input) == 2)), \
            r'either one tensor to be used as "stored pattern" and' \
            r' "pattern_projection" must be provided, or two separate ones.'
        if is_single_tensor:
            stored_pattern = pattern_projection = input
        else:
            stored_pattern, pattern_projection = input

        # Expand the size-one batch dimension of the pooling weights to the batch size of the input.
        batch_size = stored_pattern.shape[0 if self.batch_first else 1]
        expanded_dims = (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size)
        state_pattern = self.pooling_weights.expand(size=(*expanded_dims, self.pooling_weights.shape[2]))
        return stored_pattern, state_pattern, pattern_projection

    def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor]], stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Compute Hopfield-based pooling on specified data.

        :param input: data to be pooled
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: Hopfield-pooled input data, with all dimensions after the first one flattened
        """
        pooled = self.hopfield(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
        return pooled.flatten(start_dim=1)

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                               stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield association matrix used for pooling gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            return self.hopfield.get_association_matrix(
                input=self._prepare_input(input=input),
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch Hopfield projected pattern matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            return self.hopfield.get_projected_pattern_matrix(
                input=self._prepare_input(input=input),
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    # The accessors below forward to the equally named property of the wrapped "Hopfield" module
    # ("quantity" is the only one held by this wrapper itself).

    @property
    def batch_first(self) -> bool:
        return self.hopfield.batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        return self.hopfield.scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        return self.hopfield.stored_pattern_dim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        return self.hopfield.state_pattern_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        return self.hopfield.pattern_projection_dim

    @property
    def input_size(self) -> Optional[int]:
        return self.hopfield.input_size

    @property
    def hidden_size(self) -> Optional[int]:
        # Annotated as optional to agree with the forwarded "Hopfield.hidden_size" property.
        return self.hopfield.hidden_size

    @property
    def output_size(self) -> Optional[int]:
        return self.hopfield.output_size

    @property
    def pattern_size(self) -> Optional[int]:
        return self.hopfield.pattern_size

    @property
    def quantity(self) -> int:
        return self._quantity

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        return self.hopfield.update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        return self.hopfield.update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        return self.hopfield.stored_pattern_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        return self.hopfield.state_pattern_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        return self.hopfield.pattern_projection_as_static

class HopfieldLayer(Module):
    """
    Hopfield-based lookup layer.

    Bundles a set of lookup weights (fed to the wrapped "Hopfield" module as stored pattern), optionally
    separate lookup target weights (fed as pattern projection) and the "Hopfield" module itself. The data
    passed to "forward" is used as state pattern.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,
                 lookup_weights_as_separated: bool = False,
                 lookup_targets_as_trainable: bool = True,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_pattern_projection: bool = True,
                 normalize_pattern_projection_affine: bool = True,
                 normalize_hopfield_space: bool = False,
                 normalize_hopfield_space_affine: bool = False,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 pattern_projection_as_static: bool = False,
                 pattern_projection_as_connected: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 pattern_projection_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,
                 quantity: int = 1,
                 trainable: bool = True
                 ):
        """
        Set up the wrapped Hopfield module together with the lookup (and optional lookup target) weights.

        :param input_size: depth of the input (state pattern)
        :param hidden_size: depth of the association space
        :param output_size: depth of the output projection
        :param pattern_size: depth of patterns to be selected
        :param num_heads: amount of parallel association heads
        :param scaling: scaling of association heads, often represented as beta (one entry per head)
        :param update_steps_max: maximum count of association update steps (None equals to infinity)
        :param update_steps_eps: minimum difference threshold between two consecutive association update steps
        :param lookup_weights_as_separated: keep lookup target weights as a parameter separate from lookup weights
        :param lookup_targets_as_trainable: let the separate lookup target weights require gradients
        :param normalize_stored_pattern: apply normalization on stored patterns
        :param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
        :param normalize_state_pattern: apply normalization on state patterns
        :param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
        :param normalize_pattern_projection: apply normalization on the pattern projection
        :param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
        :param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
        :param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
        :param stored_pattern_as_static: interpret specified stored patterns as being static
        :param state_pattern_as_static: interpret specified state patterns as being static
        :param pattern_projection_as_static: interpret specified pattern projections as being static
        :param pattern_projection_as_connected: connect pattern projection with stored pattern
        :param stored_pattern_size: depth of input (stored pattern)
        :param pattern_projection_size: depth of input (pattern projection)
        :param batch_first: flag for specifying if the first dimension of data fed to "forward" reflects the batch size
        :param association_activation: additional activation to be applied on the result of the Hopfield association
        :param dropout: dropout probability applied on the association matrix
        :param input_bias: bias to be added to input (state and stored pattern as well as pattern projection)
        :param concat_bias_pattern: bias to be concatenated to stored pattern as well as pattern projection
        :param add_zero_association: add a new batch of zeros to stored pattern as well as pattern projection
        :param disable_out_projection: disable output projection
        :param quantity: amount of stored patterns (rows of the lookup weights)
        :param trainable: let the lookup weights require gradients
        """
        super().__init__()

        # Everything except the lookup-specific options is handed to the wrapped Hopfield module unchanged.
        hopfield_config = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            pattern_size=pattern_size,
            num_heads=num_heads,
            scaling=scaling,
            update_steps_max=update_steps_max,
            update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            normalize_pattern_projection=normalize_pattern_projection,
            normalize_pattern_projection_affine=normalize_pattern_projection_affine,
            normalize_hopfield_space=normalize_hopfield_space,
            normalize_hopfield_space_affine=normalize_hopfield_space_affine,
            stored_pattern_as_static=stored_pattern_as_static,
            state_pattern_as_static=state_pattern_as_static,
            pattern_projection_as_static=pattern_projection_as_static,
            pattern_projection_as_connected=pattern_projection_as_connected,
            stored_pattern_size=stored_pattern_size,
            pattern_projection_size=pattern_projection_size,
            batch_first=batch_first,
            association_activation=association_activation,
            dropout=dropout,
            input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern,
            add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection,
        )
        self.hopfield = Hopfield(**hopfield_config)
        self._quantity = quantity

        # Both weight tensors carry a singleton batch axis; its position follows "batch_first".
        leading_dims = (1, quantity) if batch_first else (quantity, 1)

        # Depth of the lookup weights: taken from the Hopfield module, falling back to the input depth.
        if stored_pattern_as_static:
            lookup_weight_size = self.hopfield.hidden_size
        else:
            lookup_weight_size = self.hopfield.stored_pattern_dim
        if lookup_weight_size is None:
            lookup_weight_size = input_size
        self.lookup_weights = nn.Parameter(
            torch.empty(size=(*leading_dims, lookup_weight_size)), requires_grad=trainable)

        if lookup_weights_as_separated:
            if pattern_projection_size is None:
                target_weight_size = self.lookup_weights.shape[2]
            else:
                target_weight_size = pattern_projection_size
            self.target_weights = nn.Parameter(
                torch.empty(size=(*leading_dims, target_weight_size)), requires_grad=lookup_targets_as_trainable)
        else:
            # Registered as None so that the attribute exists and "_prepare_input" can test for it.
            self.register_parameter(name=r'target_weights', param=None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Re-initialise the wrapped Hopfield module (if it supports it) as well as lookup and target weights.

        :return: None
        """
        if hasattr(self.hopfield, r'reset_parameters'):
            self.hopfield.reset_parameters()

        # Lookup weights first, then (if present) target weights, both drawn from N(0, 0.02^2).
        weights_to_initialise = [self.lookup_weights]
        if self.target_weights is not None:
            weights_to_initialise.append(self.target_weights)
        for weights in weights_to_initialise:
            nn.init.normal_(weights, mean=0.0, std=0.02)

    def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Assemble the triple expected by the Hopfield module from the given state pattern.

        :param input: data to be used as state pattern
        :return: stored pattern (expanded lookup weights), unchanged input, and pattern projection
                 (expanded target weights, or the stored pattern itself if no target weights exist)
        """
        batch_axis = 0 if self.batch_first else 1
        batch_size = input.shape[batch_axis]
        if self.batch_first:
            leading_dims = (batch_size, self.quantity)
        else:
            leading_dims = (self.quantity, batch_size)

        # Broadcast the singleton batch axis of the weights to the batch size of the input.
        stored_pattern = self.lookup_weights.expand(size=(*leading_dims, self.lookup_weights.shape[2]))
        pattern_projection = stored_pattern
        if self.target_weights is not None:
            pattern_projection = self.target_weights.expand(
                size=(*leading_dims, self.target_weights.shape[2]))

        return stored_pattern, input, pattern_projection

    def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Run the Hopfield-based lookup for the specified data.

        :param input: data to be used in the lookup (state pattern)
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: result of the wrapped Hopfield module for the prepared input
        """
        prepared = self._prepare_input(input=input)
        return self.hopfield(
            input=prepared,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)

    def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch the association matrix of the wrapped Hopfield module for the specified data (no gradients).

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: association matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            prepared = self._prepare_input(input=input)
            return self.hopfield.get_association_matrix(
                input=prepared,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                                     stored_pattern_padding_mask: Optional[Tensor] = None,
                                     association_mask: Optional[Tensor] = None) -> Tensor:
        """
        Fetch the projected pattern matrix of the wrapped Hopfield module for the specified data (no gradients).

        :param input: data to be passed through the Hopfield association
        :param stored_pattern_padding_mask: mask to be applied on stored patterns
        :param association_mask: mask to be applied on inner association matrix
        :return: pattern projection matrix as computed by the Hopfield core module
        """
        with torch.no_grad():
            prepared = self._prepare_input(input=input)
            return self.hopfield.get_projected_pattern_matrix(
                input=prepared,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    # All properties below, except "quantity", simply forward the value of the wrapped Hopfield module.

    @property
    def batch_first(self) -> bool:
        """Forwarded "batch_first" of the wrapped Hopfield module."""
        return self.hopfield.batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        """Forwarded "scaling" of the wrapped Hopfield module."""
        return self.hopfield.scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        """Forwarded "stored_pattern_dim" of the wrapped Hopfield module."""
        return self.hopfield.stored_pattern_dim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        """Forwarded "state_pattern_dim" of the wrapped Hopfield module."""
        return self.hopfield.state_pattern_dim

    @property
    def pattern_projection_dim(self) -> Optional[int]:
        """Forwarded "pattern_projection_dim" of the wrapped Hopfield module."""
        return self.hopfield.pattern_projection_dim

    @property
    def input_size(self) -> Optional[int]:
        """Forwarded "input_size" of the wrapped Hopfield module."""
        return self.hopfield.input_size

    @property
    def hidden_size(self) -> int:
        """Forwarded "hidden_size" of the wrapped Hopfield module."""
        return self.hopfield.hidden_size

    @property
    def output_size(self) -> Optional[int]:
        """Forwarded "output_size" of the wrapped Hopfield module."""
        return self.hopfield.output_size

    @property
    def pattern_size(self) -> Optional[int]:
        """Forwarded "pattern_size" of the wrapped Hopfield module."""
        return self.hopfield.pattern_size

    @property
    def quantity(self) -> int:
        """Amount of stored patterns as passed to the constructor."""
        return self._quantity

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        """Forwarded "update_steps_max" of the wrapped Hopfield module."""
        return self.hopfield.update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        """Forwarded "update_steps_eps" of the wrapped Hopfield module."""
        return self.hopfield.update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        """Forwarded "stored_pattern_as_static" of the wrapped Hopfield module."""
        return self.hopfield.stored_pattern_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        """Forwarded "state_pattern_as_static" of the wrapped Hopfield module."""
        return self.hopfield.state_pattern_as_static

    @property
    def pattern_projection_as_static(self) -> bool:
        """Forwarded "pattern_projection_as_static" of the wrapped Hopfield module."""
        return self.hopfield.pattern_projection_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        """Forwarded "normalize_stored_pattern" of the wrapped Hopfield module."""
        return self.hopfield.normalize_stored_pattern

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        """Forwarded "normalize_stored_pattern_affine" of the wrapped Hopfield module."""
        return self.hopfield.normalize_stored_pattern_affine

    @property
    def normalize_state_pattern(self) -> bool:
        """Forwarded "normalize_state_pattern" of the wrapped Hopfield module."""
        return self.hopfield.normalize_state_pattern

    @property
    def normalize_state_pattern_affine(self) -> bool:
        """Forwarded "normalize_state_pattern_affine" of the wrapped Hopfield module."""
        return self.hopfield.normalize_state_pattern_affine

    @property
    def normalize_pattern_projection(self) -> bool:
        """Forwarded "normalize_pattern_projection" of the wrapped Hopfield module."""
        return self.hopfield.normalize_pattern_projection

    @property
    def normalize_pattern_projection_affine(self) -> bool:
        """Forwarded "normalize_pattern_projection_affine" of the wrapped Hopfield module."""
        return self.hopfield.normalize_pattern_projection_affine
mode 100644 index 0000000000000000000000000000000000000000..95278bbf7a3923658822391d5b1433e5b750c165 Binary files /dev/null and b/src/mhnfs/hopfield/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/mhnfs/hopfield/modules/__pycache__/__init__.cpython-39.pyc b/src/mhnfs/hopfield/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dadf35bb102978b133cf39fc9c256dd53730aecf Binary files /dev/null and b/src/mhnfs/hopfield/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-37.pyc b/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30ba1f5f8709e4c6910d179dbbba71bc629b944e Binary files /dev/null and b/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-37.pyc differ diff --git a/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-38.pyc b/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..546029f7a27cdd97a69876aeb23f5e28942e5dc4 Binary files /dev/null and b/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-38.pyc differ diff --git a/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-39.pyc b/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb910f32c6ced43766d7117d2afe7d9b84114ed3 Binary files /dev/null and b/src/mhnfs/hopfield/modules/__pycache__/activation.cpython-39.pyc differ diff --git a/src/mhnfs/hopfield/modules/__pycache__/functional.cpython-37.pyc b/src/mhnfs/hopfield/modules/__pycache__/functional.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6924f592598e3b8e995d8cdd7144a4680c4e4265 Binary files /dev/null and b/src/mhnfs/hopfield/modules/__pycache__/functional.cpython-37.pyc differ diff --git 
class HopfieldCore(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See references: "Hopfield Networks is All You Need" and
    "Attention Is All You Need" (on which this implementation is partly based).

    .. math::
        \text{HopfieldHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model. May only be None for static-only execution.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        head_dim: dimension of the association space of each head. Default: embed_dim // num_heads.
        pattern_dim: dimension of each projected value input. Default: the resolved head dimension.
        out_dim: dimension of the output projection. Default: embed_dim.
        disable_out_projection: do not create an output projection.
        key_as_static, query_as_static, value_as_static: do not create a projection weight for
            the respective input.
        value_as_connected: connect value projection with key projection (requires kdim == vdim).
        normalize_pattern: enable normalization of patterns.
        normalize_pattern_affine: additionally create affine parameters (weight and bias of size
            head dimension) for the pattern normalization.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::

        >>> hopfield_attn = HopfieldCore(embed_dim, num_heads)
        >>> attn_output, attn_output_weights, attn_matrix = hopfield_attn(query, key, value)
    """

    # typing.Optional is used directly instead of going through the private torch._jit_internal module.
    __annotations__ = {
        "bias_k": Optional[torch.Tensor],
        "bias_v": Optional[torch.Tensor],
    }

    def __init__(
        self,
        embed_dim=None,  # type: Optional[int]
        num_heads=1,  # type: int
        dropout=0.0,  # type: float
        bias=True,  # type: bool
        add_bias_kv=False,  # type: bool
        add_zero_attn=False,  # type: bool
        kdim=None,  # type: Optional[int]
        vdim=None,  # type: Optional[int]
        head_dim=None,  # type: Optional[int]
        pattern_dim=None,  # type: Optional[int]
        out_dim=None,  # type: Optional[int]
        disable_out_projection=False,  # type: bool
        key_as_static=False,  # type: bool
        query_as_static=False,  # type: bool
        value_as_static=False,  # type: bool
        value_as_connected=False,  # type: bool
        normalize_pattern=False,  # type: bool
        normalize_pattern_affine=False,  # type: bool
    ):
        """
        Create all projection, normalization and bias parameters required by the chosen configuration.

        See the class docstring for a description of the arguments. Invalid combinations are
        rejected with an AssertionError.
        """
        super(HopfieldCore, self).__init__()

        # The static flags are summed up below, hence they need to be genuine booleans.
        assert (
            isinstance(key_as_static, bool)
            and isinstance(query_as_static, bool)
            and isinstance(value_as_static, bool)
        )
        self.key_as_static, self.query_as_static, self.value_as_static = (
            key_as_static,
            query_as_static,
            value_as_static,
        )
        num_non_static = 3 - (
            self.key_as_static + self.query_as_static + self.value_as_static
        )
        assert 0 <= num_non_static < 4

        self.value_as_connected = value_as_connected
        self.normalize_pattern, self.normalize_pattern_affine = (
            normalize_pattern,
            normalize_pattern_affine,
        )
        self.disable_out_projection = disable_out_projection

        # In case of a static-only executions, check corresponding projections and normalizations.
        self.static_execution = self._check_execution_mode()
        if self.static_execution:
            embed_dim, kdim, vdim = None, None, None
        if embed_dim is None:
            assert (
                self.static_execution
            ), r"static-only execution requires all projections to be deactivated."

        # Check and set all other properties, conditioned on the execution mode determined above.
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = all(
            (
                self.kdim == embed_dim,
                self.vdim == embed_dim,
                pattern_dim is None,
                not self.value_as_connected,
            )
        )
        assert (not self.value_as_connected) or (
            self.kdim == self.vdim
        ), r"key and value need to be of same dimension."

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = None
        self.pattern_dim = pattern_dim
        self.virtual_hopfield_dim = None
        self.virtual_pattern_dim = None
        if not self.static_execution:
            if head_dim is None:
                self.head_dim = embed_dim // num_heads
                assert (
                    self.head_dim * num_heads == self.embed_dim
                ), "embed_dim must be divisible by num_heads."
            else:
                assert (
                    head_dim > 0
                ), "dimension of the association space has to be positive."
                self.head_dim = head_dim
            if self.pattern_dim is None:
                self.pattern_dim = self.head_dim
            self.virtual_hopfield_dim = self.num_heads * self.head_dim
            self.virtual_pattern_dim = self.num_heads * self.pattern_dim

        self.out_dim = embed_dim if out_dim is None else out_dim
        assert disable_out_projection or (
            self.out_dim > 0
        ), "output projection dimension has to be positive."

        if normalize_pattern_affine:
            assert (
                normalize_pattern
            ), "affine pattern normalization without pattern normalization has no effect."
            # Size by the *resolved* head dimension: the raw "head_dim" argument is None whenever the
            # default (embed_dim // num_heads) is used, which previously made this allocation fail.
            # Affine normalization rules out static-only execution, so self.head_dim is set here.
            self.p_norm_weight = Parameter(torch.empty(self.head_dim))
            self.p_norm_bias = Parameter(torch.empty(self.head_dim))
        else:
            self.register_parameter("p_norm_weight", None)
            self.register_parameter("p_norm_bias", None)

        if self._qkv_same_embed_dim is False:
            # Separate projection weights per input; static inputs get no projection at all.
            if query_as_static:
                self.register_parameter("q_proj_weight", None)
            else:
                self.q_proj_weight = Parameter(
                    torch.empty(self.virtual_hopfield_dim, embed_dim)
                )
            if key_as_static:
                self.register_parameter("k_proj_weight", None)
            else:
                self.k_proj_weight = Parameter(
                    torch.empty(self.virtual_hopfield_dim, self.kdim)
                )
            if value_as_static:
                self.register_parameter("v_proj_weight", None)
            else:
                # A connected value is projected from the key projection if the key is not static.
                self.v_proj_weight = Parameter(
                    torch.empty(
                        self.virtual_pattern_dim,
                        self.virtual_hopfield_dim
                        if (value_as_connected and not key_as_static)
                        else self.vdim,
                    )
                )
            self.register_parameter("in_proj_weight", None)
        else:
            # One packed projection weight holding the rows of all non-static inputs.
            if num_non_static > 0:
                self.in_proj_weight = Parameter(
                    torch.empty(
                        (not query_as_static) * self.virtual_hopfield_dim
                        + (not key_as_static) * self.virtual_hopfield_dim
                        + (not value_as_static) * self.virtual_pattern_dim,
                        embed_dim,
                    )
                )
            else:
                self.register_parameter("in_proj_weight", None)
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias and (num_non_static > 0):
            # NOTE(review): unlike in_proj_weight, the value part is not gated by "not value_as_static".
            # Kept unchanged, since altering the shape would break loading of existing checkpoints.
            self.in_proj_bias = Parameter(
                torch.empty(
                    (not query_as_static) * self.virtual_hopfield_dim
                    + (not key_as_static) * self.virtual_hopfield_dim
                    + self.virtual_pattern_dim
                )
            )
        else:
            self.register_parameter("in_proj_bias", None)
        if disable_out_projection:
            self.register_parameter("out_proj", None)
        else:
            # _LinearWithBias only exists in some PyTorch versions (None otherwise, see module imports).
            if bias and _LinearWithBias is not None:
                self.out_proj = _LinearWithBias(self.virtual_pattern_dim, self.out_dim)
            else:
                self.out_proj = Linear(
                    self.virtual_pattern_dim, self.out_dim, bias=bias
                )

        self.bias_k, self.bias_v = None, None
        if add_bias_kv:
            if not key_as_static:
                self.bias_k = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            if not value_as_static:
                # NOTE(review): sized by virtual_hopfield_dim rather than virtual_pattern_dim - confirm
                # against hopfield_core_forward before changing (shape is part of saved checkpoints).
                self.bias_v = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
            assert not (
                self.bias_k is None and self.bias_v is None
            ), r"cannot set key/value bias if both are static."

        self.add_zero_attn = add_zero_attn
        self.reset_parameters()

    def _check_execution_mode(self) -> bool:
        """
        Check whether the module is configured for static-only execution.

        :return: True if all inputs are static and neither connection, normalization nor output
                 projection is enabled
        """
        return all(
            (
                self.key_as_static,
                self.query_as_static,
                self.value_as_static,
                not self.value_as_connected,
                not self.normalize_pattern,
                not self.normalize_pattern_affine,
                self.disable_out_projection,
            )
        )

    def reset_parameters(self):
        """
        (Re-)initialise all existing parameters: normalization to identity (weight 1, bias 0),
        projection weights and key/value biases from N(0, 0.02^2), projection biases to 0.
        """
        if self.p_norm_weight is not None:
            nn.init.ones_(self.p_norm_weight)
            nn.init.zeros_(self.p_norm_bias)

        if self._qkv_same_embed_dim and (self.in_proj_weight is not None):
            nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02)
        else:
            if self.q_proj_weight is not None:
                nn.init.normal_(self.q_proj_weight, mean=0.0, std=0.02)
            if self.k_proj_weight is not None:
                nn.init.normal_(self.k_proj_weight, mean=0.0, std=0.02)
            if self.v_proj_weight is not None:
                nn.init.normal_(self.v_proj_weight, mean=0.0, std=0.02)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
        if not self.disable_out_projection:
            nn.init.normal_(self.out_proj.weight, mean=0.0, std=0.02)
            if self.out_proj.bias is not None:
                nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.normal_(self.bias_k, mean=0.0, std=0.02)
        if self.bias_v is not None:
            nn.init.normal_(self.bias_v, mean=0.0, std=0.02)

    def __setstate__(self, state):
        # Plain delegation, kept as an explicit hook.
        super(HopfieldCore, self).__setstate__(state)

    def forward(
        self,
        query,  # type: Tensor
        key,  # type: Tensor
        value,  # type: Tensor
        key_padding_mask=None,  # type: Optional[Tensor]
        need_weights=True,  # type: bool
        attn_mask=None,  # type: Optional[Tensor]
        scaling=None,  # type: Optional[Tensor]
        update_steps_max=0,  # type: Optional[int]
        update_steps_eps=1e-4,  # type: float
        return_raw_associations=False,  # type: bool
        return_pattern_projections=False,  # type: bool
    ):
        # type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
                See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

            scaling: scaling of association heads, often represented as beta (one entry per head).
            update_steps_max: maximum count of association update steps (None equals to infinity).
            update_steps_eps: minimum difference threshold between two consecutive association update steps.
            return_raw_associations: return raw association (softmax) values, unmodified.
            return_pattern_projections: return pattern projection values, unmodified.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

            - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
            - attn_raw: :math:`(N, num_heads, L, S)`, where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        # Validate the feature dimension (shape[2]) of query and key against the configuration.
        if self.query_as_static and self.key_as_static:
            assert (
                query.shape[2] == key.shape[2]
            ), f"query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal"
            head_dim, embed_dim_to_check = query.shape[2], query.shape[2]
        else:
            assert self.query_as_static or (
                query.shape[2] == self.embed_dim
            ), f"query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}."
            assert (not self.query_as_static) or (
                query.shape[2] == self.head_dim
            ), f"query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}"

            assert self.key_as_static or (
                key.shape[2] == self.kdim
            ), f"key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}."
            assert (not self.key_as_static) or (
                key.shape[2] == self.head_dim
            ), f"key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}"
            head_dim, embed_dim_to_check = (
                self.head_dim,
                self.head_dim if self.query_as_static else self.embed_dim,
            )

        assert self.value_as_static or (
            value.shape[2] == self.vdim
        ), f"value shape[2] of {value.shape[2]} invalid, needs to be {self.vdim}."
        assert any(
            (
                not self.value_as_static,
                self.value_as_static and value.shape[2] == self.pattern_dim,
                self.disable_out_projection,
            )
        ), f"value shape[2] of {value.shape[2]} invalid, needs to be {self.pattern_dim}"

        out_weights, out_bias = None, None
        if not self.disable_out_projection:
            out_weights, out_bias = self.out_proj.weight, self.out_proj.bias

        # A single call serves both parameter layouts: with a packed projection the separate
        # q/k/v weights are registered as None, which matches the defaults of the functional.
        return hopfield_core_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=embed_dim_to_check,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            bias_k=self.bias_k,
            bias_v=self.bias_v,
            add_zero_attn=self.add_zero_attn,
            dropout_p=self.dropout,
            out_proj_weight=out_weights,
            out_proj_bias=out_bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=not self._qkv_same_embed_dim,
            q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
            key_as_static=self.key_as_static,
            query_as_static=self.query_as_static,
            value_as_static=self.value_as_static,
            value_as_connected=self.value_as_connected,
            normalize_pattern=self.normalize_pattern,
            p_norm_weight=self.p_norm_weight,
            p_norm_bias=self.p_norm_bias,
            head_dim=head_dim,
            pattern_dim=self.pattern_dim,
            scaling=scaling,
            update_steps_max=update_steps_max,
            update_steps_eps=update_steps_eps,
            return_raw_associations=return_raw_associations,
            return_projected_patterns=return_pattern_projections,
        )
# type: Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None, # type: Optional[Tensor] + + key_as_static=False, # type: bool + query_as_static=False, # type: bool + value_as_static=False, # type: bool + value_as_connected=False, # type: bool + normalize_pattern=False, # type: bool + p_norm_weight=None, # type: Optional[Tensor] + p_norm_bias=None, # type: Optional[Tensor] + head_dim=None, # type: Optional[int] + pattern_dim=None, # type: Optional[int] + scaling=None, # type: Optional[Union[float, Tensor]] + update_steps_max=0, # type: Optional[Union[int, Tensor]] + update_steps_eps=1e-4, # type: Union[float, Tensor] + return_raw_associations=False, # type: bool + return_projected_patterns=False # type: bool + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks. + embed_dim_to_check: total dimension of the model (in case of default head dimension). + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. 
+ out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + key_as_static: interpret specified key as being static. + query_as_static: interpret specified key as being static. + value_as_static: interpret specified key as being static. + value_as_connected: connect value projection with key projection. + normalize_pattern: enable normalization of patterns. + p_norm_weight, p_norm_bias: pattern normalization weight and bias. + head_dim: dimensionality of each head. + pattern_dim: dimensionality of each projected value input. + scaling: scaling of association heads, often represented as beta (one entry per head). + update_steps_max: maximum count of association update steps (None equals to infinity). + update_steps_eps: minimum difference threshold between two consecutive association update steps. + return_raw_associations: return raw association (softmax) values, unmodified. + return_projected_patterns: return pattern projection values, unmodified. 
+ + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, head_dim)`, where S is the source sequence length, N is the batch size. + - static_v: :math:`(N*num_heads, S, head_dim)`, where S is the source sequence length, N is the batch size. + + - scaling: :math:`(num_heads,)`, where num_heads is the amount of heads. + + Outputs: + - attn_output: :math:`(L, N, E)`, where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)`, where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ - attn_raw: :math:``(N, num_heads, L, S)`, where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not torch.jit.is_scripting(): + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + out_proj_weight, out_proj_bias) + if any([type(t) is not Tensor for t in tens_ops]) and nn.functional.has_torch_function(tens_ops): + return nn.functional.handle_torch_function( + hopfield_core_forward, tens_ops, query, key, value, + embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + out_proj_bias, training=training, key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v, + key_as_static=key_as_static, query_as_static=query_as_static, + value_as_static=value_as_static, value_as_connected=value_as_connected, + normalize_pattern=normalize_pattern, p_norm_weight=p_norm_weight, p_norm_bias=p_norm_bias, + head_dim=head_dim, pattern_dim=pattern_dim, scaling=scaling, update_steps_max=update_steps_max, + update_steps_eps=update_steps_eps, return_raw_associations=return_raw_associations) + tgt_len, bsz, embed_dim = query.shape[0], value.shape[1], query.shape[2] + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + assert (scaling is None) or (type(scaling) in (float, torch.Tensor)) + if type(scaling) == torch.Tensor: + assert scaling.ndimension() == 1 and scaling.shape[0] == num_heads, "only one entry per head." 
+ + assert (update_steps_max is None) or (type(update_steps_max) in (int, torch.Tensor)) + if type(update_steps_max) == torch.Tensor: + assert update_steps_max.ndimension() == 1 and update_steps_max.shape[0] == num_heads, "only one entry per head." + elif type(update_steps_max) == int: + update_steps_max = torch.tensor([update_steps_max] * num_heads, dtype=torch.int32, device=query.device) + elif update_steps_max is None: + update_steps_max = -torch.ones(size=(num_heads,), dtype=torch.int32, device=query.device) + + assert type(update_steps_eps) in (float, torch.Tensor) + if type(update_steps_eps) == torch.Tensor: + assert update_steps_eps.ndimension() == 1 and update_steps_eps.shape[0] == num_heads, "only one entry per head." + assert (update_steps_eps <= 0.0).sum() == 0, "only positive thresholds allowed." + update_steps_eps = update_steps_eps.to(device=query.device) + elif type(update_steps_eps) == float: + assert update_steps_eps > 0, "only positive thresholds allowed." + update_steps_eps = torch.tensor([update_steps_eps] * num_heads, dtype=query.dtype, device=query.device) + + # Adapt dimensionality of each each. + if head_dim is None: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, r'embed_dim must be divisible by num_heads.' + hopfield_dim = num_heads * head_dim + + # Adapt dimensionality of each value projection. + if pattern_dim is None: + pattern_dim = head_dim + assert (not value_as_connected) or (pattern_dim == head_dim) + + q, k, v, xi, src_len = None, None, None, None, 0 + update_step, xi_old, xi_difference_norm = 0, None, float(r'+inf') + update_active_heads = torch.tensor([[[True]]] * num_heads * bsz, device=query.device) + assert update_active_heads.any(), "at least one head needs to be active." 
+ + #################################################################################################################### + # BEGIN HOPFIELD UPDATE ITERATION # + #################################################################################################################### + + while update_active_heads.any(): + + # The query is already projected into the "Hopfield" space at "update_step" equals 0. + # No more projection necessary if "update_step" greater than 0. + if update_step == 0: + if not use_separate_proj_weight: + + if torch.equal(query, key) and torch.equal(key, value) and not ( + key_as_static or query_as_static or value_as_static): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value) and not (key_as_static or value_as_static): + # encoder-decoder attention + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + _start = hopfield_dim + _end = None + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + _start += hopfield_dim + _end += hopfield_dim + + if key_as_static: + k = key.repeat(1, num_heads, 1) + else: + # This is inline in_proj 
function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + _start += hopfield_dim + _end += hopfield_dim + + if value_as_static: + v = value.repeat(1, num_heads, 1) + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + v = nn.functional.linear(value, _w, _b) + else: + _start, _end = 0, hopfield_dim + if query_as_static: + q = query.repeat(1, num_heads, 1) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == hopfield_dim and len2 == query.size(-1) + if in_proj_bias is not None: + q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias[_start:_end]) + _start += hopfield_dim + _end += hopfield_dim + else: + q = nn.functional.linear(query, q_proj_weight_non_opt, in_proj_bias) + + v = value + if key_as_static: + k = key.repeat(1, num_heads, 1) + else: + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == hopfield_dim and len2 == key.size(-1) + + _bias = None if in_proj_bias is None else in_proj_bias[_start:_end] + k = nn.functional.linear(key, k_proj_weight_non_opt, _bias) + if value_as_connected: + v = nn.functional.linear(v, k_proj_weight_non_opt, _bias) + _start += hopfield_dim + _end += num_heads * pattern_dim + + if value_as_static: + if not (value_as_connected or key_as_static): + v = v.repeat(1, num_heads, 1) + else: + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == (num_heads * pattern_dim) and len2 == v.size(-1) + if in_proj_bias is not None: + v = nn.functional.linear(v, v_proj_weight_non_opt, in_proj_bias[_start:]) + else: + v = nn.functional.linear(v, 
v_proj_weight_non_opt, in_proj_bias) + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or \ + attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # Optionally normalize patterns. + if normalize_pattern: + q = torch.nn.functional.layer_norm( + input=q.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,), + weight=p_norm_weight, bias=p_norm_bias).reshape(shape=q.shape) + k = torch.nn.functional.layer_norm( + input=k.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,), + weight=p_norm_weight, bias=p_norm_bias).reshape(shape=k.shape) + + else: + active_xi = xi.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:])) + active_k = k.masked_select(mask=update_active_heads).view(size=(-1, *k.shape[1:])) + q = torch.masked_scatter(input=q, mask=update_active_heads, source=torch.bmm(active_xi, active_k)) + + # Optionally scale association heads (each head separately). 
+ if type(scaling) == float: + q = q * scaling + elif type(scaling) == torch.Tensor: + q = q * scaling.view(1, 1, -1).repeat(repeats=(1, 1, q.shape[2] // scaling.shape[0])) + + if update_step == 0: + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.HopfieldCore is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None and key_as_static is None and value_as_static is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = nn.functional.pad(attn_mask, [0, 1]) + if key_padding_mask is not None: + key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1]) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + assert not key_as_static, "bias cannot be added to static key." + assert not value_as_static, "bias cannot be added to static value." 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, -1, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(v.shape[0], bsz * num_heads, -1).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == pattern_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = nn.functional.pad(attn_mask, [0, 1]) + if key_padding_mask is not None: + key_padding_mask = nn.functional.pad(key_padding_mask, [0, 1]) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + # Compute new xi for Hopfield retrieve iterations. 
+ if xi is None: + xi = nn.functional.softmax(attn_output_weights, dim=-1) + else: + xi = torch.masked_scatter(input=xi, mask=update_active_heads, source=nn.functional.softmax( + attn_output_weights.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:])), dim=-1)) + + # Compute threshold-based stopping criterion for Hopfield retrieve iterations. + with torch.no_grad(): + xi_active = xi.view(size=(bsz, num_heads, tgt_len, src_len)) + update_active_heads = (update_step < update_steps_max) | (update_steps_max < 0) + if xi_old is not None: + update_active_heads &= ((xi_old - xi_active).norm(p=2, dim=(2, 3)).max(axis=0)[0]) > update_steps_eps + update_active_heads = update_active_heads.unsqueeze(dim=1).unsqueeze(dim=2).repeat(repeats=(bsz, 1, 1)) + xi_old = xi_active + update_step += 1 + + #################################################################################################################### + # END HOPFIELD UPDATE ITERATION # + #################################################################################################################### + + attn_output_weights = nn.functional.dropout(xi, p=dropout_p, training=training) + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.shape[:2]) == [bsz * num_heads, tgt_len] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1) + if out_proj_weight is not None: + assert attn_output.shape[2] == num_heads * pattern_dim + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + xi = xi.view(bsz, num_heads, tgt_len, src_len) if return_raw_associations else None + v = v.view(bsz, num_heads, src_len, -1) if return_projected_patterns else None + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads, xi, v + else: + return attn_output, None, xi, v diff --git 
class HopfieldEncoderLayer(Module):
    """
    Encoder layer for transformer-like architectures, built around a Hopfield association
    followed by a two-layer feed-forward block (each with residual connection and layer norm).
    """

    def __init__(self,
                 hopfield_association: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Set up the encoder layer around a private copy of the given Hopfield association.

        :param hopfield_association: instance of Hopfield association module (deep-copied)
        :param dim_feedforward: width of the hidden layer of the internal feed-forward block
        :param dropout: dropout probability used by all internal dropout layers
        :param activation: name of a function in the ``torch`` namespace applied in the feed-forward block
        """
        super().__init__()
        self.hopfield_association = deepcopy(hopfield_association)
        pattern_dim = self.hopfield_association.state_pattern_dim

        # Feed-forward block: pattern_dim -> dim_feedforward -> pattern_dim.
        self.linear_residual = nn.Linear(pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, pattern_dim)

        # Normalisation and dropout applied around the two residual connections.
        self.norm_residual = nn.LayerNorm(pattern_dim)
        self.norm_output = nn.LayerNorm(pattern_dim)
        self.dropout_hopfield_association = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is looked up by name on the torch module itself (e.g. torch.relu).
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Re-initialise every sub-module that offers ``reset_parameters`` (Hopfield association included).

        :return: None
        """
        components = (self.hopfield_association, self.linear_residual,
                      self.linear_output, self.norm_residual, self.norm_output)
        for component in components:
            if not hasattr(component, r'reset_parameters'):
                continue
            component.reset_parameters()

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Encode the given data with the Hopfield association and the feed-forward block.

        :param src: data to be processed by Hopfield encoder module
        :param src_mask: mask to be applied on association matrix
        :param src_key_padding_mask: mask to be applied on stored patterns
        :return: Hopfield-encoded input data
        """
        # First residual branch: Hopfield association.
        associated = self.hopfield_association(
            input=src, association_mask=src_mask, stored_pattern_padding_mask=src_key_padding_mask)
        normed = self.norm_residual(input=src + self.dropout_hopfield_association(input=associated))

        # Second residual branch: position-wise feed-forward block.
        hidden = self.activation_residual(input=self.linear_residual(input=normed))
        projected = self.linear_output(input=self.dropout_residual(input=hidden))
        return self.norm_output(input=normed + self.dropout_output(input=projected))

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Delegate to the wrapped Hopfield association to obtain its association matrix for the given data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        association = self.hopfield_association
        return association.get_association_matrix(input=input)

    @property
    def batch_first(self) -> int:
        # Forwarded unchanged from the wrapped Hopfield association.
        return self.hopfield_association.batch_first

    @property
    def input_size(self) -> int:
        # Forwarded unchanged from the wrapped Hopfield association.
        return self.hopfield_association.input_size

    @property
    def output_size(self) -> int:
        # Width produced by the last linear projection of the feed-forward block.
        return self.linear_output.out_features
class HopfieldDecoderLayer(Module):
    """
    Module with underlying Hopfield associations (self- and cross-association) to be used as a decoder
    in transformer-like architectures.
    """

    def __init__(self,
                 hopfield_association_self: Hopfield,
                 hopfield_association_cross: Hopfield,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = r'relu'
                 ):
        """
        Initialise a new instance of a Hopfield association-based decoder module.

        :param hopfield_association_self: instance of Hopfield self-association module
        :param hopfield_association_cross: instance of Hopfield cross-association module
        :param dim_feedforward: depth of the linear projections applied internally
        :param dropout: dropout probability to be applied internally
        :param activation: activation to be applied on the result of the internal linear projections
        """
        super(HopfieldDecoderLayer, self).__init__()
        # Both associations are deep-copied, so the layer never shares parameters with the caller's modules.
        self.hopfield_association_self = deepcopy(hopfield_association_self)
        self.hopfield_association_cross = deepcopy(hopfield_association_cross)

        self.linear_residual = nn.Linear(self.hopfield_association_self.state_pattern_dim, dim_feedforward)
        self.dropout_residual = nn.Dropout(dropout)
        self.linear_output = nn.Linear(dim_feedforward, self.hopfield_association_self.state_pattern_dim)

        self.norm_residual_self = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.norm_residual_cross = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.norm_output = nn.LayerNorm(self.hopfield_association_self.state_pattern_dim)
        self.dropout_hopfield_association_self = nn.Dropout(dropout)
        self.dropout_hopfield_association_cross = nn.Dropout(dropout)
        self.dropout_output = nn.Dropout(dropout)

        # The activation is resolved by name from the torch namespace (e.g. torch.relu).
        self.activation_residual = getattr(torch, activation, None)
        assert self.activation_residual is not None, r'invalid activation function supplied.'
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Reset parameters, including both Hopfield associations.

        :return: None
        """
        for module in (self.hopfield_association_self, self.hopfield_association_cross,
                       self.linear_residual, self.linear_output, self.norm_residual_self,
                       self.norm_residual_cross, self.norm_output):
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Apply Hopfield decoding on specified data.

        :param tgt: data to be processed by Hopfield decoder module (self-association)
        :param memory: data to be processed by Hopfield encoder module (cross-association)
        :param tgt_mask: mask to be applied on self-association matrix
        :param memory_mask: mask to be applied on cross-association matrix
        :param tgt_key_padding_mask: mask to be applied on stored patterns
        :param memory_key_padding_mask: mask to be applied on state patterns as well as pattern projection
        :return: Hopfield-decoded input
        """
        # Self-association with residual connection and normalisation.
        data_associated = self.hopfield_association_self(
            input=tgt, stored_pattern_padding_mask=tgt_key_padding_mask,
            association_mask=tgt_mask)
        tgt = tgt + self.dropout_hopfield_association_self(input=data_associated)
        tgt = self.norm_residual_self(input=tgt)

        # Cross-association: the input triple is passed as (memory, tgt, memory).
        data_associated = self.hopfield_association_cross(
            input=(memory, tgt, memory), stored_pattern_padding_mask=memory_key_padding_mask,
            association_mask=memory_mask)
        tgt = tgt + self.dropout_hopfield_association_cross(input=data_associated)
        tgt = self.norm_residual_cross(input=tgt)

        # Feed-forward block with residual connection and final normalisation.
        result_residual_inner = self.activation_residual(input=self.linear_residual(input=tgt))
        data_associated = self.linear_output(input=self.dropout_residual(input=result_residual_inner))
        tgt = tgt + self.dropout_output(input=data_associated)
        return self.norm_output(input=tgt)

    def get_association_matrix_self(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield self-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_self.get_association_matrix(input=input)

    def get_association_matrix_cross(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]]) -> Tensor:
        """
        Fetch Hopfield cross-association matrix gathered by passing through the specified data.

        :param input: data to be passed through the Hopfield association
        :return: association matrix as computed by the Hopfield core module
        """
        return self.hopfield_association_cross.get_association_matrix(input=input)

    @property
    def batch_first(self) -> int:
        return self.hopfield_association_self.batch_first

    @property
    def input_size(self) -> int:
        return self.hopfield_association_self.input_size

    @property
    def output_size(self) -> int:
        # The layer only defines "linear_output"; the former "linear_output_self" never existed,
        # so reading this property always raised an AttributeError.
        return self.linear_output.out_features
def init_lecun(m):
    """
    Initialise a linear layer in place with LeCun-normal weights and a zero bias.

    Weights are drawn from a normal distribution with mean 0 and standard deviation
    ``sqrt(1 / m.in_features)``. A layer created with ``bias=False`` is supported.

    :param m: layer exposing ``weight``, ``bias`` and ``in_features`` (e.g. ``nn.Linear``)
    """
    # Plain float arithmetic; no need for a tensor -> numpy round trip to get a scalar.
    nn.init.normal_(m.weight, mean=0.0, std=(1.0 / m.in_features) ** 0.5)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


def init_kaiming(m, nonlinearity):
    """
    Initialise a linear layer in place with Kaiming-normal (fan-in) weights and a zero bias.

    :param m: layer exposing ``weight`` and ``bias`` (e.g. ``nn.Linear``)
    :param nonlinearity: name of the non-linearity, forwarded to ``nn.init.kaiming_normal_``
    """
    nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity=nonlinearity)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


@torch.no_grad()
def init_weights(m, activation_function="linear"):
    """
    Initialise ``m`` according to the activation function that follows it.

    Only modules whose type is exactly ``nn.Linear`` are touched; everything else is
    left unchanged. "relu" selects Kaiming-normal initialisation, "selu" and "linear" select
    LeCun-normal initialisation. Any other value leaves the module unchanged.

    :param m: module to be initialised
    :param activation_function: one of "relu", "selu" or "linear"
    """
    # Exact type match on purpose: subclasses of nn.Linear keep their own initialisation.
    if type(m) is not nn.Linear:
        return
    if activation_function == "relu":
        init_kaiming(m, nonlinearity="relu")
    elif activation_function in ("selu", "linear"):
        init_lecun(m)
class MHNfs(pl.LightningModule):
    """
    The MHNfs is a few-shot drug-discovery model for activity prediction.

    For a requested query molecule, MHNfs predicts activity, while known knowledge from
    the support set is used.

    MHNfs can be seen as an embedding-based few-shot method since the prediction is
    based on similarities of molecule representations in a learned "representation
    space". Being able to build rich, expressive molecule representations is the key for
    a predictive model.

    MHNfs consists of
    three consecutive modules:
    - the context module,
    - the cross attention module, and
    - the similarity module.

    The context module associates the query and support set molecules with context -
    i.e., a large set of training molecules.

    The cross-attention module allows for information sharing between query and support
    set molecules.

    The similarity module computes pair-wise similarity values between query and sup-
    port set molecules and uses these similarity values to build a prediction from a
    weighted sum over the support set labels.
    """

    def __init__(self, cfg):
        """
        Build all sub-modules and load the context set from the repository assets.

        :param cfg: configuration object; the attributes read here are
            ``cfg.model.associationSpace_dim``, ``cfg.model.layerNormBlock.affine`` and
            ``cfg.model.prediction_scaling`` (the sub-modules receive the whole object)
        """
        super(MHNfs, self).__init__()

        # Imported locally so that this block does not depend on a module-level import.
        from pathlib import Path

        # Config
        self.cfg = cfg

        # Load context set
        # The repository root is three levels above this file (<root>/src/mhnfs/model.py).
        # Resolving via pathlib instead of splitting on "/" also works with Windows path
        # separators and with a relative "__file__".
        repository_root = Path(__file__).resolve().parents[2]
        context_file = repository_root / "assets" / "mhnfs_data" / "full_context_set.npy"
        self.context = (
            torch.unsqueeze(
                torch.from_numpy(np.load(context_file)),
                0,
            )
            .float()
        )

        # Placeholder until "_update_context_set_embedding" computes the real embedding.
        # NOTE(review): "context" and "context_embedding" are plain attributes, not registered
        # buffers, so they presumably do not follow the module on ".to(device)" - confirm.
        self.context_embedding = torch.ones(1, 512, 1024)

        self.layerNorm_context = torch.nn.LayerNorm(
            cfg.model.associationSpace_dim,
            elementwise_affine=cfg.model.layerNormBlock.affine,
        )

        # Encoder
        self.encoder = EncoderBlock(cfg)

        # Context module
        self.contextModule = ContextModule(self.cfg)

        # Layernormalizing-block
        self.layerNormBlock = LayerNormalizingBlock(cfg)

        # Cross-attention module
        self.crossAttentionModule = CrossAttentionModule(self.cfg)

        # Similarity module (stored as a callable, invoked with the config on every call)
        self.similarity_function = SimilarityModule

        # Output function
        self.sigmoid = torch.nn.Sigmoid()
        self.prediction_scaling = cfg.model.prediction_scaling

    def forward(
        self,
        query_molecules: torch.Tensor,
        support_molecules_active: torch.Tensor,
        support_molecules_inactive: torch.Tensor,
        support_set_actives_size: torch.Tensor,
        support_set_inactives_size: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict activity values for the query molecules given an active/inactive support set.

        The support set tensors are expanded along their first dimension to the number of
        query entries before the context module is applied.

        :param query_molecules: encoder input for the query molecules
        :param support_molecules_active: encoder input for the active support set molecules
        :param support_molecules_inactive: encoder input for the inactive support set molecules
        :param support_set_actives_size: number of active support set molecules
        :param support_set_inactives_size: number of inactive support set molecules
        :return: sigmoid of the scaled difference between the active and inactive
            support set similarity predictions
        """
        # Get embeddings from molecule encoder
        query_embedding = self.encoder(query_molecules)
        support_actives_embedding = self.encoder(support_molecules_active)
        support_inactives_embedding = self.encoder(support_molecules_inactive)

        # Retrieve updated representations from the context module
        # - Layernorm
        (
            query_embedding,
            support_actives_embedding,
            support_inactives_embedding,
        ) = self.layerNormBlock(
            query_embedding, support_actives_embedding, support_inactives_embedding
        )

        # - Expand support set related tensors
        support_actives_embedding = support_actives_embedding.expand(
            query_embedding.shape[0], -1, -1)
        support_inactives_embedding = support_inactives_embedding.expand(
            query_embedding.shape[0], -1, -1)
        support_set_actives_size = support_set_actives_size.expand(
            query_embedding.shape[0])
        support_set_inactives_size = support_set_inactives_size.expand(
            query_embedding.shape[0])

        # - Context module
        (
            query_embedding,
            support_actives_embedding,
            support_inactives_embedding,
        ) = self.contextModule(
            query_embedding,
            support_actives_embedding,
            support_inactives_embedding,
            self.context_embedding,
        )

        # Allow for information sharing between query and support set
        # - Layernorm
        (
            query_embedding,
            support_actives_embedding,
            support_inactives_embedding,
        ) = self.layerNormBlock(
            query_embedding, support_actives_embedding, support_inactives_embedding
        )

        # - Cross-attention module
        (
            query_embedding,
            support_actives_embedding,
            support_inactives_embedding,
        ) = self.crossAttentionModule(
            query_embedding,
            support_actives_embedding,
            support_inactives_embedding,
        )

        # Build predictions from a weighted sum over support set labels
        # - Layernorm:
        if self.cfg.model.layerNormBlock.usage:
            (
                query_embedding,
                support_actives_embedding,
                support_inactives_embedding,
            ) = self.layerNormBlock(
                query_embedding, support_actives_embedding, support_inactives_embedding
            )

        # - Similarity module:
        predictions_support_actives = self.similarity_function(
            query_embedding,
            support_actives_embedding,
            support_set_actives_size,
            self.cfg,
        )

        predictions_support_inactives = self.similarity_function(
            query_embedding,
            support_inactives_embedding,
            support_set_inactives_size,
            self.cfg,
        )

        predictions = predictions_support_actives - predictions_support_inactives

        predictions = self.sigmoid(self.prediction_scaling * predictions)

        return predictions

    @torch.no_grad()
    def _update_context_set_embedding(self):
        """
        This function randomly samples the context set as a subset of all available
        training molecules and stores its layer-normalised encoder embedding in
        ``self.context_embedding``.
        """
        max_rows = self.context.shape[1]
        number_requested_rows = int(
            np.rint(self.cfg.model.context.ratio_training_molecules * max_rows)
        )

        # Random subset without replacement: the first entries of a random permutation.
        sampled_rows = torch.randperm(max_rows)[:number_requested_rows]

        self.context_embedding = self.layerNorm_context(
            self.encoder(self.context[:, sampled_rows, :])
        )
np.rint(self.cfg.model.context.ratio_training_molecules * max_rows) + ) + + sampled_rows = torch.randperm(max_rows)[:number_requested_rows] + + self.context_embedding = self.layerNorm_context( + self.encoder(self.context[:, sampled_rows, :]) + ) diff --git a/src/mhnfs/modules.py b/src/mhnfs/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..6767205f780c9d6d9f4b471e4171cfd063a1f812 --- /dev/null +++ b/src/mhnfs/modules.py @@ -0,0 +1,574 @@ +import torch +import torch.nn as nn +#import hydra +from omegaconf import OmegaConf +from functools import partial +import math +from torch.nn import functional as F + +from src.mhnfs.hopfield.modules import Hopfield +from src.mhnfs.initialization import init_weights + + +# Mappings +activation_function_mapping = { + "relu": nn.ReLU(), + "selu": nn.SELU(), + "sigmoid": nn.Sigmoid(), +} + +dropout_mapping = {"relu": nn.Dropout, "selu": nn.AlphaDropout} + + +# Modules +class EncoderBlock(nn.Module): + """ + Fully connected molecule encoder block. 
def __init__(self, cfg: OmegaConf):
    """
    Build the fully connected encoder.

    Layout: input layer -> cfg.model.encoder.number_hidden_layers hidden layers ->
    output layer. Every layer is a (dropout, linear, activation) triple.

    Config values read here (all below cfg.model):
    - encoder.activation: key into dropout_mapping / activation_function_mapping
    - encoder.regularization.input_dropout: dropout rate of the input layer
    - encoder.regularization.dropout: dropout rate of hidden and output layers
    - encoder.input_dim: size of the input features
    - encoder.number_hidden_neurons: width of input and hidden layers
    - encoder.number_hidden_layers: number of hidden layers
    - associationSpace_dim: size of the returned representation
    """
    super(EncoderBlock, self).__init__()

    # Input layer
    self.dropout = dropout_mapping[cfg.model.encoder.activation](
        cfg.model.encoder.regularization.input_dropout
    )
    self.fc = nn.Linear(
        cfg.model.encoder.input_dim, cfg.model.encoder.number_hidden_neurons
    )
    # The mapping stores module instances, so the same activation object is reused
    # for every layer.
    self.act = activation_function_mapping[cfg.model.encoder.activation]

    # Hidden layer
    # Three parallel lists; entry i of each list belongs to hidden layer i.
    self.hidden_linear_layers = nn.ModuleList([])
    self.hidden_dropout_layers = nn.ModuleList([])
    self.hidden_activations = nn.ModuleList([])

    for _ in range(cfg.model.encoder.number_hidden_layers):
        self.hidden_dropout_layers.append(
            dropout_mapping[cfg.model.encoder.activation](
                cfg.model.encoder.regularization.dropout
            )
        )
        self.hidden_linear_layers.append(
            nn.Linear(
                cfg.model.encoder.number_hidden_neurons,
                cfg.model.encoder.number_hidden_neurons,
            )
        )
        self.hidden_activations.append(
            activation_function_mapping[cfg.model.encoder.activation]
        )

    # Output layer
    self.dropout_o = dropout_mapping[cfg.model.encoder.activation](
        cfg.model.encoder.regularization.dropout
    )
    self.fc_o = nn.Linear(
        cfg.model.encoder.number_hidden_neurons,
        cfg.model.associationSpace_dim,
    )
    self.act_o = activation_function_mapping[cfg.model.encoder.activation]

    # Initialization
    # NOTE(review): partial() binds the activation name to the FIRST positional
    # parameter of init_weights, which is the module `m`; apply() then passes each
    # sub-module as the second positional argument, i.e. as `activation_function`.
    # None of the string comparisons inside init_weights match in that case, so the
    # layers appear to keep PyTorch's default initialization. Binding by keyword
    # (activation_function=...) would enable the custom initialization, but it
    # would presumably also change how random numbers are consumed during
    # construction - re-validate seeded results before changing it.
    encoder_initialization = partial(init_weights, cfg.model.encoder.activation)
    self.apply(encoder_initialization)
+ x = self.dropout_o(x) + x = self.fc_o(x) + x = self.act_o(x) + + return x + + +class ContextModule(nn.Module): + """ + Allows for mutual information sharing. + Enriches the query and support set embeddings with context by associating a query or + support set molecule with the context set, i.e., large set of training molecules: + - The context set can be seen as an external memory + - For a given molecule embedding, a Modern Hopfield Network retrieves a representa- + tion from the external memory + + Since we have to retrieve representations for all query and support set molecules we + stack all embeddings together and perform a "batch-retrieval". + """ + + def __init__(self, cfg: OmegaConf): + super(ContextModule, self).__init__() + + self.cfg = cfg + + self.hopfield = Hopfield( + input_size=self.cfg.model.associationSpace_dim, + hidden_size=cfg.model.hopfield.dim_QK, + stored_pattern_size=self.cfg.model.associationSpace_dim, + pattern_projection_size=self.cfg.model.associationSpace_dim, + output_size=self.cfg.model.associationSpace_dim, + num_heads=self.cfg.model.hopfield.heads, + scaling=self.cfg.model.hopfield.beta, + dropout=self.cfg.model.hopfield.dropout, + ) + + # Initialization + hopfield_initialization = partial(init_weights, "linear") + self.hopfield.apply(hopfield_initialization) + + def forward( + self, + query_embedding: torch.Tensor, + support_actives_embedding: torch.Tensor, + support_inactives_embedding: torch.Tensor, + context_set_embedding: torch.Tensor, + ) -> tuple: + """ + inputs: + - query; torch.tensor; + dim: [batch-size, 1, initial-embedding-dimension] + * e.g.: [512, 1, 1024] + * initial-embedding-dimension: defined by molecule encoder block + - active support set molecules; torch.tensor; + dim: [batch-size, active-padding-dim, initial-embedding-dimension] + * e.g.: [512, 9, 1024] + - inactive support set molecules; torch.tensor; + dim: [batch-size, inactive-padding-dim, initial-embedding-dimension] + * e.g.: [512, 11, 1024] + - context 
set molecules; torch.tensor; + dim: [1, number-of-context-molecules, initial-embedding-dimension] + * e.g.: [1, 512, 1024] + + return: + tuple which includes the updated representations for query, active, and inactive + support set molecules: + (query, active support set molecules, inactive support set molecules) + """ + # Stack embeddings together to perform a "batch-retrieval" + s = torch.cat( + (query_embedding, support_actives_embedding, support_inactives_embedding), + dim=1, + ) + s_flattend = s.reshape(1, s.shape[0] * s.shape[1], s.shape[2]) + + # Retrieval + s_h = self.hopfield((context_set_embedding, s_flattend, context_set_embedding)) + + # Combine retrieval with skip connection + s_updated = s_flattend + s_h + s_updated_inputShape = s_updated.reshape( + s.shape[0], s.shape[1], s.shape[2] + ) # reshape tensor back to input shape + + query_embedding = s_updated_inputShape[:, 0, :] + query_embedding = torch.unsqueeze(query_embedding, 1) + + # Split query, active and inactive support set embeddings + padding_size_actives = support_actives_embedding.shape[1] + + support_actives_embedding = s_updated_inputShape[ + :, 1 : (padding_size_actives + 1), : + ] + support_inactives_embedding = s_updated_inputShape[ + :, (padding_size_actives + 1) :, : + ] + + return query_embedding, support_actives_embedding, support_inactives_embedding + + +class LayerNormalizingBlock(nn.Module): + """ + Layernorm-block which scales/transforms the representations for query, ac- + tive, and inactive support set molecules. 
+ """ + + def __init__(self, cfg: OmegaConf): + super(LayerNormalizingBlock, self).__init__() + + self.cfg = cfg + + if cfg.model.layerNormBlock.usage: + self.layernorm_query = nn.LayerNorm( + cfg.model.associationSpace_dim, + elementwise_affine=cfg.model.layerNormBlock.affine, + ) + self.layernorm_support_actives = nn.LayerNorm( + cfg.model.associationSpace_dim, + elementwise_affine=cfg.model.layerNormBlock.affine, + ) + self.layernorm_support_inactives = nn.LayerNorm( + cfg.model.associationSpace_dim, + elementwise_affine=cfg.model.layerNormBlock.affine, + ) + + def forward( + self, + query_embedding: torch.Tensor, + support_actives_embedding: torch.Tensor, + support_inactives_embedding: torch.Tensor, + ) -> tuple: + """ + inputs: + - query; torch.tensor; + dim: [batch-size, 1, embedding-dim] + * e.g.: [512, 1, 1024] + - active support set molecules; torch.tensor; + dim: [batch-size, active-padding-dim, embedding-dim] + * e.g.: [512, 9, 1024] + - inactive support set molecules; torch.tensor; + dim: [batch-size, inactive-padding-dim, initial-embedding-dim] + * e.g.: [512, 11, 1024] + + return: + tuple which includes the updated representations for query, active, and inactive + support set molecules: + (query, active support set molecules, inactive support set molecules) + """ + + # Layer normalization + # Since the layernorm operations are optional the module just updates represen- + # tations if the the referring option is set in the config. 
+ if self.cfg.model.layerNormBlock.usage: + query_embedding = self.layernorm_query(query_embedding) + support_actives_embedding = self.layernorm_support_actives( + support_actives_embedding + ) + if support_inactives_embedding is not None: + support_inactives_embedding = self.layernorm_support_inactives( + support_inactives_embedding + ) + return query_embedding, support_actives_embedding, support_inactives_embedding + + +class CrossAttentionModule(nn.Module): + """ + The cross-attention module allows for information sharing between query and support + set molecules. + + Altae-Tran et al. [1] showed that representations can be enriched by making the + query molecule aware of the support set molecules and making the support set mole- + cules aware of each other and of the query molecule. We enable information sharing + with a transformer. + + Overview of the cross-attention module: + 1) The query and support set molecules are concatenated such that one joint matrix + emerges which includes both query and support set molecules. + 2) The joint matrix is fed into a transformer + - Self-attention enables information sharing between query and support set mole- + cules + + [1] Altae-Tran, H., Ramsundar, B., Pappu, A. S., & Pande, V. (2017). Low data drug + discovery with one-shot learning. ACS central science, 3(4), 283-293. 
def __init__(self, cfg: OmegaConf):
    """
    Build the cross-attention module.

    The transformer block is configured from the nested GPTConfig class, not from
    cfg. cfg is only stored on the instance here.
    """

    super(CrossAttentionModule, self).__init__()

    self.cfg = cfg

    # GPTConfig and TranformerBlock are looked up as attributes of this class.
    cfg_gpt = self.GPTConfig()
    self.transformer_block = self.TranformerBlock(cfg_gpt)

    # Initialization
    # NOTE(review): partial() binds 'relu' to the FIRST positional parameter of
    # init_weights, which is the module `m`; apply() then passes each sub-module as
    # `activation_function`. None of the string comparisons inside init_weights
    # match in that case, so this call appears to leave the default initialization
    # in place. Binding by keyword (activation_function='relu') would enable the
    # custom initialization, but it would presumably also change how random numbers
    # are consumed during construction - re-validate seeded results first.
    encoder_initialization = partial(init_weights, 'relu')
    self.apply(encoder_initialization)
(-1.0) + * torch.ones_like( + support_inactives_embedding[ + :, :, : self.cfg.model.transformer.activity_embedding_dim + ] + ), + ], + dim=2, + ) + + # Concatenate query and support set molecules + s = torch.cat( + [query_embedding, support_actives_embedding, support_inactives_embedding], + dim=1, + ) + + # Run transformer and update representations + s_h = self.transformer_block(s) + s_updated = s + s_h + + # Split representations into query, active, and inactive support set molecules + query_embedding = s_updated[:, 0, :embedding_dim] + query_embedding = torch.unsqueeze(query_embedding, 1) + support_actives_embedding = s_updated[ + :, 1 : (support_actives_embedding.shape[1] + 1), :embedding_dim + ] + support_inactives_embedding = s_updated[ + :, (support_actives_embedding.shape[1] + 1) :, :embedding_dim + ] + + return query_embedding, support_actives_embedding, support_inactives_embedding + + #----------------------------------------------------------------------------------- + # Sub-modules + class TranformerBlock(nn.Module): + + def __init__(self, config): + super().__init__() + self.ln_1 = self.LayerNorm(config.n_embd, bias=config.bias) + self.attn = self.SelfAttention(config) + self.ln_2 = self.LayerNorm(config.n_embd, bias=config.bias) + self.mlp = self.MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + class LayerNorm(nn.Module): + """ + LayerNorm but with an optional bias. 
PyTorch doesn't support simply + bias=False + """ + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, + 1e-5) + + class SelfAttention(nn.Module): + """ + Self Attention Block + """ + + def __init__(self, config): + super().__init__() + + self.cfg = config + + # query, key, value projections + self.q_proj = nn.Linear(config.n_embd, config.n_qk_proj*config.n_head, + bias=config.bias) + self.k_proj = nn.Linear(config.n_embd, config.n_qk_proj*config.n_head, + bias=config.bias) + self.v_proj = nn.Linear(config.n_embd, config.n_v_proj*config.n_head, + bias=config.bias) + + # output projection + self.c_proj = nn.Linear(config.n_v_proj*config.n_head, config.n_embd, + bias=config.bias) + + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dim (n_embd) + + # Calculate queries, keys, and values + q = self.q_proj(x).view(B, T, self.n_head, self.cfg.n_qk_proj + ).transpose(1, 2) # (B, nh, T, hs) + k = self.k_proj(x).view(B, T, self.n_head, self.cfg.n_qk_proj + ).transpose(1, 2) # (B, nh, T, hs) + v = self.v_proj(x).view(B, T, self.n_head, self.cfg.n_v_proj + ).transpose(1, 2) # (B, nh, T, hs) + + # Calculate self-attentions + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + # Activations + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + + # Re-assemble all head outputs side by side + y = y.transpose(1, 2).contiguous().view(y.shape[0], + y.shape[2], + -1) + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return 
def SimilarityModule(
    query_embedding: torch.Tensor,
    support_set_embeddings: torch.Tensor,
    support_set_size: torch.Tensor,
    cfg: "OmegaConf",
) -> torch.Tensor:
    """
    The similarity module builds the activity prediction for the query molecule from a
    weighted sum over the support set labels. Pair-wise similarity values between query
    and support set molecules are used as weights for the weighted sum.

    Since the similarity module is applied twice within the MHNfs model - once for the
    active and once for the inactive support set molecules - the support_set_embeddings
    here mean either active or inactive support set molecule embeddings.

    inputs:
    - query; torch.tensor;
      dim: [batch-size, 1, embedding-dimension]
      * e.g.: [512, 1, 1024]
    - support set molecules; torch.tensor;
      dim: [batch-size, padding-dim, embedding-dimension]
      * e.g.: [512, 9, 1024]
    - support set size; torch.tensor;
      dim: [batch-size]
      * e.g.: [512]
    - cfg; config object; read values:
      * cfg.model.similarityModule.l2Norm: if truthy, embeddings are L2-normalized
      * cfg.model.similarityModule.scaling: "1/N" or "1/sqrt(N)"; any other value
        leaves the summed similarities unscaled

    return:
    - summed (and optionally scaled) similarities; torch.tensor;
      dim: [batch-size, 1]
    """

    # Optional L2-norm
    if cfg.model.similarityModule.l2Norm:
        query_norm = torch.unsqueeze(query_embedding.pow(2).sum(dim=2).sqrt(), 2)
        query_norm[query_norm == 0] = 1  # all-zero vectors stay zero, no 0/0
        support_norm = torch.unsqueeze(
            support_set_embeddings.pow(2).sum(dim=2).sqrt(), 2
        )
        support_norm[support_norm == 0] = 1

        query_embedding = query_embedding / query_norm
        support_set_embeddings = support_set_embeddings / support_norm

    # Compute similarity values
    # dim:
    # [batch-size, 1, padding-dim] =
    # [batch-size, 1, emb-dim] x [batch-size, emb-dim, padding-dim]
    similarities = query_embedding @ torch.transpose(support_set_embeddings, 1, 2)

    # NaN similarities do not contribute to the sum
    similarities[torch.isnan(similarities)] = 0.0

    # For every query molecule: sum over support set molecules
    similarity_sums = similarities.sum(dim=2)

    # Scaling (the two options are mutually exclusive)
    scaling = cfg.model.similarityModule.scaling
    stabilizer = torch.tensor(1e-8).float()  # avoids a division by zero for N == 0
    sizes = support_set_size.reshape(-1, 1)
    if scaling == "1/N":
        similarity_sums = 1 / (2.0 * sizes + stabilizer) * similarity_sums
    elif scaling == "1/sqrt(N)":
        similarity_sums = (
            1 / (2.0 * torch.sqrt(sizes.float()) + stabilizer) * similarity_sums
        )

    return similarity_sums
class ActivityPredictor:
    """
    Thin wrapper around the MHNfs model for making activity predictions.

    On construction the model is loaded from
    <repository root>/assets/mhnfs_data/mhnfs_checkpoint.ckpt (the repository root is
    the parent of the directory containing this file), its context set embedding is
    computed and the model is switched to eval mode. Loading is wrapped in
    st.cache_resource.
    """

    def __init__(self):

        @st.cache_resource  # Caching for streamlit
        def load_model():
            # Local import keeps this change self-contained.
            from pathlib import Path

            # Fixed seed, set before the model is built and the context is sampled
            pl.seed_everything(1234)

            # Path handling via pathlib: splitting __file__ on "/" returns a wrong
            # root when __file__ is relative or uses Windows separators.
            repo_root = Path(__file__).resolve().parents[1]
            checkpoint = repo_root / "assets" / "mhnfs_data" / "mhnfs_checkpoint.ckpt"

            model = MHNfs.load_from_checkpoint(str(checkpoint))
            model._update_context_set_embedding()
            model.eval()

            return model

        # Load model
        self.model = load_model()

        # Initiate query mol storage
        self.query_molecules = None

    def predict(self, query_smiles, support_activces_smiles, support_inactives_smiles):
        """
        Predict the activity of the query molecules given a support set.

        inputs:
        - query_smiles: query molecules; stored unchanged in self.query_molecules
          and passed to create_query_input
        - support_activces_smiles: active support set molecules; passed to
          create_support_set_input
        - support_inactives_smiles: inactive support set molecules; passed to
          create_support_set_input

        return:
        - 1-D numpy array with the flattened model outputs
        """

        # Create model inputs
        # Query input
        self.query_molecules = query_smiles
        query_input = create_query_input(query_smiles)

        # Active support set input
        support_actives_input, support_actives_size = create_support_set_input(
            support_activces_smiles
        )

        # Inactive support set input
        support_inactives_input, support_inactives_size = create_support_set_input(
            support_inactives_smiles
        )

        # Make predictions
        predictions = self.model(
            query_input,
            support_actives_input,
            support_inactives_input,
            support_actives_size,
            support_inactives_size,
        )

        # .cpu() is a no-op on CPU and keeps the conversion valid on other devices
        preds_numpy = predictions.detach().cpu().numpy().flatten()

        return preds_numpy

    def _return_query_mols_as_list(self):
        """
        Return the query molecules of the last predict() call as a list.

        Supported storage types:
        - list: returned as is
        - str: split at "," and stripped of surrounding whitespace
        - pd.DataFrame: values of the "smiles" column; if there is no column with
          exactly that name, the first column whose name equals "smiles" when
          compared case-insensitively is used

        raises:
        - ValueError: nothing stored yet, or a DataFrame without a smiles column
        - TypeError: any other stored type
        """
        stored = self.query_molecules

        if stored is None:
            raise ValueError("No query molecules have been stored yet. "
                             "Run predict-function first.")

        if isinstance(stored, list):
            return stored

        if isinstance(stored, str):
            return [smiles.strip() for smiles in stored.split(",")]

        if isinstance(stored, pd.DataFrame):
            if "smiles" in stored.columns:
                return stored["smiles"].tolist()
            # The preprocessing tests treat a "smiLES" column as valid input, so
            # the column name is matched case-insensitively here as well.
            for column in stored.columns:
                if str(column).strip().lower() == "smiles":
                    return stored[column].tolist()
            raise ValueError("Query DataFrame has no 'smiles' column.")

        raise TypeError("Type of query molecules not recognized. "
                        "Please check input type.")
--git a/src/tests/__pycache__/conftest.cpython-36-pytest-7.0.1.pyc b/src/tests/__pycache__/conftest.cpython-36-pytest-7.0.1.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92991581878b602e1b74053ba7d44f3ef1eb1de6 Binary files /dev/null and b/src/tests/__pycache__/conftest.cpython-36-pytest-7.0.1.pyc differ diff --git a/src/tests/__pycache__/conftest.cpython-37-pytest-7.1.2.pyc b/src/tests/__pycache__/conftest.cpython-37-pytest-7.1.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e2a0fa9c1682595dc7bc00d5f0aecbc34ccc7d Binary files /dev/null and b/src/tests/__pycache__/conftest.cpython-37-pytest-7.1.2.pyc differ diff --git a/src/tests/__pycache__/test_data_preprocessing.cpython-36-pytest-7.0.1.pyc b/src/tests/__pycache__/test_data_preprocessing.cpython-36-pytest-7.0.1.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c23c0b8683506876a4c833bbc626b35bce3d644b Binary files /dev/null and b/src/tests/__pycache__/test_data_preprocessing.cpython-36-pytest-7.0.1.pyc differ diff --git a/src/tests/__pycache__/test_data_preprocessing.cpython-37-pytest-7.1.2.pyc b/src/tests/__pycache__/test_data_preprocessing.cpython-37-pytest-7.1.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68e709e1abd86012ba3af7696cd2abc6b68ab942 Binary files /dev/null and b/src/tests/__pycache__/test_data_preprocessing.cpython-37-pytest-7.1.2.pyc differ diff --git a/src/tests/__pycache__/test_prediction_pipeline_model_preds.cpython-37-pytest-7.1.2.pyc b/src/tests/__pycache__/test_prediction_pipeline_model_preds.cpython-37-pytest-7.1.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ad965de65f1e148f09275382c2d88fe6958c154 Binary files /dev/null and b/src/tests/__pycache__/test_prediction_pipeline_model_preds.cpython-37-pytest-7.1.2.pyc differ diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 
0000000000000000000000000000000000000000..856af48f074ef6ee8d18c9cd7a92c82445680f8d --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,100 @@ +""" +Needed objects for tests +""" + +#--------------------------------------------------------------------------------------- +# Dependencies +import pytest +import pandas as pd +import pickle +import numpy as np +import torch + +from data_preprocessing.create_descriptors import create_cleaned_mol_objects + +#--------------------------------------------------------------------------------------- +# Define fixtures + +#--------------------------------------------------------------------------------------- +# Data preprocessing +@pytest.fixture(scope="session") +def input_molecule_formats(): + class Formats: + smiles = "CCO" + smiles_coma = "CCO, CCN" + smiles_list = ["CCO", "CCN"] + smiles_df = pd.DataFrame({"smiLES": ["CCO", "CCN"]}) + smiles_df_wrong_key = pd.DataFrame({"notSMILES": ["CCO", "CCN"]}) + return Formats() + +@pytest.fixture(scope="session") +def input_smiles(): + current_loc = __file__.rsplit("/",3)[0] + with open(current_loc + "/assets/test_reference_data/smiles.pkl", "rb") as fl: + input_smiles = pickle.load(fl) + return input_smiles + +@pytest.fixture(scope="session") +def input_mols_from_smiles(): + current_loc = __file__.rsplit("/",3)[0] + with open(current_loc + "/assets/test_reference_data/smiles.pkl", "rb") as fl: + input_smiles = pickle.load(fl) + + input_molecules = create_cleaned_mol_objects(input_smiles) + return input_molecules + +@pytest.fixture(scope="session") +def ecfps_from_smiles(): + current_loc = __file__.rsplit("/",3)[0] + ecfps = np.load(current_loc + "/assets/test_reference_data/ecfps.npy") + return ecfps + +@pytest.fixture(scope="session") +def rdkit_descrs_from_smiles(): + current_loc = __file__.rsplit("/",3)[0] + rdkit_descrs = np.load(current_loc + "/assets/test_reference_data/rdkit_descrs.npy") + return rdkit_descrs + +@pytest.fixture(scope="session") +def rdkit_descr_quantils(): 
+ current_loc = __file__.rsplit("/",3)[0] + rdkit_descr_quantils = np.load( + current_loc + "/assets/test_reference_data/rdkit_descr_quantils.npy") + return rdkit_descr_quantils + +@pytest.fixture(scope="session") +def preprocessed_features(): + current_loc = __file__.rsplit("/",3)[0] + preprocessed_features = np.load( + current_loc + "/assets/test_reference_data/preprocessed_features.npy") + return preprocessed_features + +#--------------------------------------------------------------------------------------- +# Model +@pytest.fixture(scope="session") +def model_input_query(): + current_loc = __file__.rsplit("/",3)[0] + model_input_query = torch.load( + current_loc + "/assets/test_reference_data/model_input_query.pt") + return model_input_query + +@pytest.fixture(scope="session") +def model_input_support_actives(): + current_loc = __file__.rsplit("/",3)[0] + model_input_support_actives = torch.load( + current_loc + "/assets/test_reference_data/model_input_support_actives.pt") + return model_input_support_actives + +@pytest.fixture(scope="session") +def model_input_support_inactives(): + current_loc = __file__.rsplit("/",3)[0] + model_input_support_inactives = torch.load( + current_loc + "/assets/test_reference_data/model_input_support_inactives.pt") + return model_input_support_inactives + +@pytest.fixture(scope="session") +def model_predictions(): + current_loc = __file__.rsplit("/",3)[0] + model_predictions = torch.load( + current_loc + "/assets/test_reference_data/model_predictions.pt") + return model_predictions \ No newline at end of file diff --git a/src/tests/test_data_preprocessing.py b/src/tests/test_data_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf747764cfb9c7aa4ec4a51f79d5b6d8f000b34 --- /dev/null +++ b/src/tests/test_data_preprocessing.py @@ -0,0 +1,113 @@ +""" +This file includes all tests for the data_preprocessing module. 
+"""
+
+import pytest
+import numpy as np
+import pickle
+from data_preprocessing.create_descriptors import (handle_inputs,
+                                                   create_ecfp_fps,
+                                                   create_rdkit_descriptors,
+                                                   create_quantils,
+                                                   preprocess_molecules)
+
+class TestPreprocessMolecules:
+
+    def test_handle_inputs(self, input_molecule_formats):
+        """
+        Check that handle_inputs turns every input format offered by the
+        fixture into a list and rejects a DataFrame with a wrong key.
+        """
+
+        # Case 1: the fixture's `smiles` entry
+        from_smiles = handle_inputs(input_molecule_formats.smiles)
+        assert isinstance(from_smiles, list)
+
+        # Case 2: the comma-separated entry must match the list entry exactly
+        from_comma = handle_inputs(input_molecule_formats.smiles_coma)
+        assert isinstance(from_comma, list)
+        assert from_comma == input_molecule_formats.smiles_list
+
+        # Case 3: the fixture's `smiles_list` entry
+        from_list = handle_inputs(input_molecule_formats.smiles_list)
+        assert isinstance(from_list, list)
+
+        # Case 4.1: a well-formed DataFrame
+        from_frame = handle_inputs(input_molecule_formats.smiles_df)
+        assert isinstance(from_frame, list)
+
+        # Case 4.2: a DataFrame with the wrong key has to raise
+        with pytest.raises(ValueError):
+            handle_inputs(input_molecule_formats.smiles_df_wrong_key)
+
+    def test_create_ecfps_fps(self, input_mols_from_smiles, ecfps_from_smiles):
+        """
+        Compare the output of create_ecfp_fps with the reference fingerprints.
+        """
+
+        # Type: the result has to be a NumPy array
+        fingerprints = create_ecfp_fps(input_mols_from_smiles)
+        assert isinstance(fingerprints, np.ndarray)
+
+        # Shape: same as the reference
+        assert fingerprints.shape == ecfps_from_smiles.shape
+
+        # Values: both tolerances are zero, so no deviation is accepted
+        assert np.allclose(fingerprints, ecfps_from_smiles, rtol=0, atol=0)
+
+    def test_create_rdkit_descriptors(self, input_mols_from_smiles,
+                                      rdkit_descrs_from_smiles):
+        """
+        Compare the output of create_rdkit_descriptors with the reference values.
+        """
+
+        # Type: the result has to be a NumPy array
+        descriptors = create_rdkit_descriptors(input_mols_from_smiles)
+        assert isinstance(descriptors, np.ndarray)
+
+        # Shape: same as the reference
+        assert descriptors.shape == rdkit_descrs_from_smiles.shape
+
+        # Values: close to the reference within np.allclose default tolerances
+        assert np.allclose(descriptors, rdkit_descrs_from_smiles)
+
+    def test_create_quantils(self, input_mols_from_smiles, rdkit_descr_quantils):
+        """
+        Compare the output of create_quantils with the reference quantiles.
+        """
+        base_dir = __file__.rsplit("/", 3)[0]  # strip the last 3 path components
+        ecdf_path = base_dir + "/assets/data_preprocessing_objects/ecdfs.pkl"
+        with open(ecdf_path, "rb") as ecdf_file:
+            ecdfs = pickle.load(ecdf_file)
+
+        descriptors = create_rdkit_descriptors(input_mols_from_smiles)
+        quantiles = create_quantils(descriptors, ecdfs)
+
+        # Type: the result has to be a NumPy array
+        assert isinstance(quantiles, np.ndarray)
+
+        # Shape: same as the reference
+        assert quantiles.shape == rdkit_descr_quantils.shape
+
+        # Values: close to the reference within np.allclose default tolerances
+        assert np.allclose(quantiles, rdkit_descr_quantils)
+
+    def test_preprocess_molecules(self, input_smiles, preprocessed_features):
+        """
+        Run preprocess_molecules on the input SMILES and compare the resulting
+        feature array with the reference features in type, shape and values
+        (values within np.allclose default tolerances).
+        """
+
+        # Type: the result has to be a NumPy array
+        features = preprocess_molecules(input_smiles)
+        assert isinstance(features, np.ndarray)
+
+        # Shape: same as the reference
+        assert features.shape == preprocessed_features.shape
+
+        # Values: close to the reference
+        assert np.allclose(features, preprocessed_features)
+
+
+
\ No newline at end of file
diff --git a/src/tests/test_prediction_pipeline_model_preds.py b/src/tests/test_prediction_pipeline_model_preds.py
new file mode 100644
index 0000000000000000000000000000000000000000..36bb640f6cc55565ae7f75a717869ea15800138c
--- /dev/null
+++ b/src/tests/test_prediction_pipeline_model_preds.py
@@ -0,0 +1,88 @@
+"""
+This file tests whether the model predictions for MHNfs match the predictions made on
+the JKU development server (verified model, server conda env with spec. packages ...)
+"""
+
+#---------------------------------------------------------------------------------------
+# Dependencies
+import pytest
+import torch
+import pandas as pd
+from prediction_pipeline import ActivityPredictor
+
+#---------------------------------------------------------------------------------------
+# Define tests
+
+class TestActivityPredictor:
+
+    def test_mhnfs_prediction(self, model_input_query, model_input_support_actives,
+                              model_input_support_inactives, model_predictions):
+        """Compare the model output with the stored reference predictions."""
+        # Build the predictor; its `model` attribute is called directly below
+        predictor = ActivityPredictor()
+
+        # Support set sizes, taken from dimension 1 of the support set inputs
+        n_actives = torch.tensor(model_input_support_actives.shape[1])
+        n_inactives = torch.tensor(model_input_support_inactives.shape[1])
+
+        # Run the model and detach the result before comparing
+        predictions = predictor.model(
+            model_input_query,
+            model_input_support_actives,
+            model_input_support_inactives,
+            n_actives,
+            n_inactives,
+        ).detach()
+
+        # Purely absolute tolerance: deviations of up to 0.01 are accepted
+        assert torch.allclose(predictions, model_predictions, rtol=0., atol=0.01)
+
+    def test_query_mol_return(self):
+        """
+        Check `_return_query_mols_as_list` after `predict` for each query format.
+
+        The helper must return the query SMILES as a list for list, string and
+        DataFrame inputs, raise ValueError when nothing is stored, and raise
+        TypeError when the stored object is of another type (here an int).
+        """
+
+        # Support set
+        support_actives_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1"]
+        support_inactives_smiles = ["CCN(CC)C(=S)SSC(=S)N(CC)CCCCC"]
+
+        # Load activity predictor
+        predictor = ActivityPredictor()
+
+        # Expected return value for every query format below
+        query_smiles = ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1",
+                        "CCN(CC)C(=S)SSC(=S)N(CC)CC"]
+
+        query_inputs = (
+            # Check 1: Query mols given as a list
+            query_smiles,
+            # Check 2: Query mols given as a string
+            ("CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1,"
+             "CCN(CC)C(=S)SSC(=S)N(CC)CC"),
+            # Check 3: Query mols given as a pd.DataFrame
+            pd.DataFrame({"smiles":
+                ["CCCCCCCCNC(C)C(O)c1ccc(SC(C)C)cc1", "CCN(CC)C(=S)SSC(=S)N(CC)CC"]}),
+        )
+        for query_input in query_inputs:
+            predictor.predict(query_input, support_actives_smiles,
+                              support_inactives_smiles)
+            assert predictor._return_query_mols_as_list() == query_smiles
+
+        # Check 4: Query molecules storage is None
+        predictor.query_molecules = None
+        with pytest.raises(ValueError):
+            predictor._return_query_mols_as_list()
+
+        # Check 5: Other data types
+        predictor.query_molecules = 123  # any other data type
+        with pytest.raises(TypeError):
+            predictor._return_query_mols_as_list()
+
+
+
+
+