diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..5407cc64f2b48d0da5303963952efa61c11d22da 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +figs/mirror-frontpage.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5a95fee2cbe69e506ba780dda73574822b7042ec --- /dev/null +++ b/.gitignore @@ -0,0 +1,187 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +.DS_Store +._.DS_Store +debug.py +outputs/ +resources/NER/msra/cache/ +resources/NER/msra/mrc/ +resources/NER/msra/formatted/ +resources/MRC/cmrc2018/cache/ +resources/MRC/cmrc2018/formatted/ +cache/*.cache +resources/MRC/DuReader-*/ +resources/**/*.json +resources/**/*.jsonl +resources/**/*.zip +resources/**/*.tsv +resources/**/*.xml +resources/**/raw/ +resources.tar.gz +debug/ +debug.json +mirror_outputs/ +sampled_stats.xlsx +mirror_fewshot_outputs/ +conll03-100.jsonl +tmp*/ +resources/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1da8d0f898fec24d4e2a98814a9e8d3398ea7a4f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + args: ["--profile", "black", "--filter-files"] +- repo: https://github.com/psf/black + rev: 22.12.0 + hooks: + - id: black +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: [--maxkb=900] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7b8a5b210ed30226cc558b91f53250629576e154 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..a153a772c60e50e5b644a0aec1110dbb28d81349 --- /dev/null +++ b/Makefile @@ -0,0 +1,26 @@ +all: format clean test pre + echo 'finished' + +.PHONY: format +format: + isort --profile black --filter-files . + black . + +.PHONY: test +test: + coverage run --source src -m pytest -vv . + coverage report -m + flake8 + +.PHONY: pre +pre: + pre-commit run --all-files + +.PHONY: clean +clean: + rm -rf build/ + rm -rf dist/ + rm -rf *.egg-info/ + rm -f .coverage + rm -f coverage.xml + find . 
| grep -E '(__pycache__|\.pyc|\.pyo$$)' | xargs rm -rf diff --git a/README.md b/README.md index 6567813f2d2afe8111afe15240ab8292b4eba473..65f75b91bdff8fa2059a3c1d1e4e1229140d297e 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,143 @@ --- title: Mirror -emoji: ๐Ÿ‘€ -colorFrom: green -colorTo: red +emoji: ๐Ÿชž +colorFrom: blue +colorTo: yellow sdk: gradio sdk_version: 4.1.2 -app_file: app.py -pinned: false +app_file: src/app/gradio_app.py +pinned: true license: apache-2.0 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +
+

๐Ÿชž Mirror: A Universal Framework for Various Information Extraction Tasks

+ Magic mirror
+ Image generated by DALL·E 3
+ + [Paper] | [Demo]
+ ๐Ÿ“ƒ Our paper has been accepted to EMNLP23 main conference, check it out!
+
+ +
+ +๐Ÿ˜Ž: This is the official implementation of [๐ŸชžMirror](https://arxiv.org/abs/2311.05419) which supports *almost* all the Information Extraction tasks. + +The name, Mirror, comes from the classical story *Snow White and the Seven Dwarfs*, where a magic mirror knows everything in the world. +We aim to build such a powerful tool for the IE community. + +## ๐Ÿ”ฅ Supported Tasks + +1. Named Entity Recognition +2. Entity Relationship Extraction (Triplet Extraction) +3. Event Extraction +4. Aspect-based Sentiment Analysis +5. Multi-span Extraction (e.g. Discontinuous NER) +6. N-ary Extraction (e.g. Hyper Relation Extraction) +7. Extractive Machine Reading Comprehension (MRC) and Question Answering +8. Classification & Multi-choice MRC + +![System Comparison](figs/sys-comparison.png) + +## ๐ŸŒด Dependencies + +Python>=3.10 + +```bash +pip install -r requirements.txt +``` + +## ๐Ÿš€ QuickStart + +### Pretrained Model Weights & Datasets + +Download the pretrained model weights & datasets from [[OSF]](https://osf.io/kwsm4/?view_only=5b66734d88cf456b93f17b6bac8a44fb) . + +No worries, it's an anonymous link just for double blind peer reviewing. + +### Pretraining + +1. Download and unzip the pretraining corpus into `resources/Mirror/v1.4_sampled_v3/merged/all_excluded` +2. Start to run + +```bash +CUDA_VISIBLE_DEVICES=0 rex train -m src.task -dc conf/Pretrain_excluded.yaml +``` + +### Fine-tuning + +โš ๏ธ Due to data license constraints, some datasets are unavailable to provide directly (e.g. ACE04, ACE05). + +1. Download and unzip the pretraining corpus into `resources/Mirror/v1.4_sampled_v3/merged/all_excluded` +2. Download and unzip the fine-tuning datasets into `resources/Mirror/uie/` +3. Start to fine-tuning + +```bash +# UIE tasks +CUDA_VISIBLE_DEVICES=0 bash scripts/single_task_wPTAllExcluded_wInstruction/run1.sh +CUDA_VISIBLE_DEVICES=1 bash scripts/single_task_wPTAllExcluded_wInstruction/run2.sh +CUDA_VISIBLE_DEVICES=2 bash scripts/single_task_wPTAllExcluded_wInstruction/run3.sh +CUDA_VISIBLE_DEVICES=3 bash scripts/single_task_wPTAllExcluded_wInstruction/run4.sh +# Multi-span and N-ary extraction +CUDA_VISIBLE_DEVICES=4 bash scripts/single_task_wPTAllExcluded_wInstruction/run_new_tasks.sh +# GLUE datasets +CUDA_VISIBLE_DEVICES=5 bash scripts/single_task_wPTAllExcluded_wInstruction/glue.sh +``` + +### Analysis Experiments + +- Few-shot experiments : `scripts/run_fewshot.sh`. Collecting results: `python mirror_fewshot_outputs/get_avg_results.py` +- Mirror w/ PT w/o Inst. : `scripts/single_task_wPTAllExcluded_woInstruction` +- Mirror w/o PT w/ Inst. : `scripts/single_task_wo_pretrain` +- Mirror w/o PT w/o Inst. : `scripts/single_task_wo_pretrain_wo_instruction` + +### Evaluation + +1. Change `task_dir` and `data_pairs` you want to evaluate. The default setting is to get results of Mirrordirect on all downstream tasks. +2. `CUDA_VISIBLE_DEVICES=0 python -m src.eval` + +### Demo + +1. Download and unzip the pretrained task dump into `mirror_outputs/Mirror_Pretrain_AllExcluded_2` +2. 
+
+### Demo
+
+1. Download and unzip the pretrained task dump into `mirror_outputs/Mirror_Pretrain_AllExcluded_2`
+2. Try our demo (an API client sketch is given in the P.S. below):
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -m src.app.api_backend
+```
+
+![Demo](figs/mirror-demo.gif)
+
+## 📋 Citation
+
+```bibtex
+@misc{zhu_mirror_2023,
+    shorttitle = {Mirror},
+    title = {Mirror: A Universal Framework for Various Information Extraction Tasks},
+    author = {Zhu, Tong and Ren, Junfei and Yu, Zijian and Wu, Mengsong and Zhang, Guoliang and Qu, Xiaoye and Chen, Wenliang and Wang, Zhefeng and Huai, Baoxing and Zhang, Min},
+    url = {http://arxiv.org/abs/2311.05419},
+    doi = {10.48550/arXiv.2311.05419},
+    urldate = {2023-11-10},
+    publisher = {arXiv},
+    month = nov,
+    year = {2023},
+    note = {arXiv:2311.05419 [cs]},
+    keywords = {Computer Science - Artificial Intelligence, Computer Science - Computation and Language},
+}
+```
+
+## 🛣️ Roadmap
+
+- [ ] Convert the current model into a Hugging Face version, supporting loading from `transformers` like other newly released LLMs.
+- [ ] Remove the `Background` area and merge `TL` and `TP` into a single `T` token.
+- [ ] Add more task data: keyword extraction, coreference resolution, FrameNet, WikiNER, the T-REx relation extraction dataset, etc.
+- [ ] Pre-train on all the data (including benchmarks) to build a nice out-of-the-box toolkit for universal IE.
+
+## 💌 Yours sincerely
+
+This project is licensed under Apache-2.0.
+We hope you enjoy it ~
+
+
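+
+P.S. Once the demo backend is running, you can also query it over HTTP. Here is a minimal client sketch: the `/process` endpoint, the `{"data": [...]}` request body, and the `ok`/`msg`/`results` response fields follow `src/app/api_backend.py`, and the instance shape follows `src/app/gradio_app.py`; the concrete instruction, schema, and text are only illustrative, and the `requests` package is assumed to be installed.
+
+```python
+import requests
+
+# A single instance; "schema" may hold "ent"/"rel"/"cls" type lists,
+# or event types mapped to their role lists.
+instance = {
+    "id": "readme-example",
+    "instruction": "Please extract entities from the given text.",
+    "schema": {"ent": ["person", "location"]},
+    "text": "Bob lives in Suzhou.",
+    "ans": {},
+}
+
+# The backend listens on port 7860 and expects a list of instances.
+resp = requests.post("http://127.0.0.1:7860/process", json={"data": [instance]})
+payload = resp.json()
+if payload["ok"]:
+    print(payload["results"])
+else:
+    print("Error:", payload["msg"])
+```
+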
+
+

Mirror Team w/ ๐Ÿ’–

+
diff --git a/conf/Pretrain_excluded.yaml b/conf/Pretrain_excluded.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c8d23a61cdecb8189933ec77c442448e84cf404 --- /dev/null +++ b/conf/Pretrain_excluded.yaml @@ -0,0 +1,51 @@ +# task +task_type: SchemaGuidedInstructBertTask +task_name: Mirror_Pretrain_AllExcluded_2 +comment: '~~content as label, (start, end + 1) span' + +# data preprocessing +max_seq_len: 512 +debug_mode: false +label_span: tag # tag `[LM]` or content `person` +mode: span # w2 (1,2,3) or span (1,3) +stream_mode: false + +# filepaths +plm_dir: microsoft/deberta-v3-large +data_dir: resources/Mirror/v1.4_sampled_v3/merged/all_excluded +output_dir: mirror_outputs +task_dir: ${output_dir}/${task_name} +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +dump_cache_dir: ${task_dir}/cache +regenerate_cache: false + +# training +random_seed: 1227 +base_model_path: null +eval_on_data: [train] +select_best_on_data: train +select_best_by_key: loss +final_eval_on_test: false +save_every_ckpt: true +save_best_ckpt: true + +warmup_proportion: 0.1 +num_epochs: 3 +epoch_patience: -1 +num_steps: -1 +step_patience: -1 +step_eval_interval: 10000 +train_batch_size: 8 +eval_batch_size: 8 +grad_accum_steps: 1 +learning_rate: !!float 2e-5 +other_learning_rate: !!float 1e-4 +max_grad_norm: 1.0 +weight_decay: 0.1 + +# model +dropout: 0.3 +use_rope: true +biaffine_size: 512 diff --git a/conf/Pretrain_v1.5.yaml b/conf/Pretrain_v1.5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8091120acbe10f010adb314d95370f7e33c37e35 --- /dev/null +++ b/conf/Pretrain_v1.5.yaml @@ -0,0 +1,51 @@ +# task +task_type: SchemaGuidedInstructBertTask +task_name: Mirror_Pretrain_DataV1.5_2 +comment: '~~content as label, (start, end + 1) span' + +# data preprocessing +max_seq_len: 512 +debug_mode: false +label_span: tag # tag `[LM]` or content `person` +mode: span # w2 (1,2,3) or span (1,3) +stream_mode: false + +# filepaths +plm_dir: microsoft/deberta-v3-large +data_dir: resources/Mirror/v1.5/merged/t-rex-200k +output_dir: mirror_outputs +task_dir: ${output_dir}/${task_name} +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +dump_cache_dir: ${task_dir}/cache +regenerate_cache: false + +# training +random_seed: 1227 +base_model_path: null +eval_on_data: [train] +select_best_on_data: train +select_best_by_key: loss +final_eval_on_test: false +save_every_ckpt: true +save_best_ckpt: true + +warmup_proportion: 0.1 +num_epochs: 3 +epoch_patience: -1 +num_steps: -1 +step_patience: -1 +step_eval_interval: 10000 +train_batch_size: 8 +eval_batch_size: 8 +grad_accum_steps: 1 +learning_rate: !!float 2e-5 +other_learning_rate: !!float 1e-4 +max_grad_norm: 1.0 +weight_decay: 0.1 + +# model +dropout: 0.3 +use_rope: true +biaffine_size: 512 diff --git a/conf/Pretrain_v1.5_woInstruction.yaml b/conf/Pretrain_v1.5_woInstruction.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91cab75be973e1f90c9b1d543a8542e2f121d104 --- /dev/null +++ b/conf/Pretrain_v1.5_woInstruction.yaml @@ -0,0 +1,51 @@ +# task +task_type: SchemaGuidedInstructBertTask +task_name: Mirror_Pretrain_DataV1.5_woInstruction +comment: '~~content as label, (start, end + 1) span' + +# data preprocessing +max_seq_len: 512 +debug_mode: false +label_span: tag # tag `[LM]` or content `person` +mode: span # w2 (1,2,3) or span (1,3) +stream_mode: false + +# filepaths +plm_dir: 
microsoft/deberta-v3-large +data_dir: resources/Mirror/v1.5/merged/t-rex-200k-woInstruction/remove_instruction +output_dir: mirror_outputs +task_dir: ${output_dir}/${task_name} +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +dump_cache_dir: ${task_dir}/cache +regenerate_cache: false + +# training +random_seed: 1227 +base_model_path: null +eval_on_data: [train] +select_best_on_data: train +select_best_by_key: loss +final_eval_on_test: false +save_every_ckpt: true +save_best_ckpt: true + +warmup_proportion: 0.1 +num_epochs: 3 +epoch_patience: -1 +num_steps: -1 +step_patience: -1 +step_eval_interval: 10000 +train_batch_size: 8 +eval_batch_size: 8 +grad_accum_steps: 1 +learning_rate: !!float 2e-5 +other_learning_rate: !!float 1e-4 +max_grad_norm: 1.0 +weight_decay: 0.1 + +# model +dropout: 0.3 +use_rope: true +biaffine_size: 512 diff --git a/conf/Pretrain_woOverlapV2.yaml b/conf/Pretrain_woOverlapV2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e68ef555106e91486618cba971e5b8f3e5f5fd3 --- /dev/null +++ b/conf/Pretrain_woOverlapV2.yaml @@ -0,0 +1,51 @@ +# task +task_type: SchemaGuidedInstructBertTask +task_name: Mirror_Pretrain_woOverlapV2 +comment: '~~content as label, (start, end + 1) span' + +# data preprocessing +max_seq_len: 512 +debug_mode: false +label_span: tag # tag `[LM]` or content `person` +mode: span # w2 (1,2,3) or span (1,3) +stream_mode: false + +# filepaths +plm_dir: microsoft/deberta-v3-large +data_dir: resources/Mirror/v1.4_sampled_v3/merged/all +output_dir: mirror_outputs +task_dir: ${output_dir}/${task_name} +train_filepath: ${data_dir}/train_wo_overlap_v2.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +dump_cache_dir: ${task_dir}/cache +regenerate_cache: false + +# training +random_seed: 1227 +base_model_path: null +eval_on_data: [train] +select_best_on_data: train +select_best_by_key: loss +final_eval_on_test: false +save_every_ckpt: true +save_best_ckpt: true + +warmup_proportion: 0.1 +num_epochs: 3 +epoch_patience: -1 +num_steps: -1 +step_patience: -1 +step_eval_interval: 10000 +train_batch_size: 8 +eval_batch_size: 8 +grad_accum_steps: 1 +learning_rate: !!float 2e-5 +other_learning_rate: !!float 1e-4 +max_grad_norm: 1.0 +weight_decay: 0.1 + +# model +dropout: 0.3 +use_rope: true +biaffine_size: 512 diff --git a/conf/ac/g1_dpspd.yaml b/conf/ac/g1_dpspd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af21ebdeb172b58c62fef4f4a3b95bc535d6e385 --- /dev/null +++ b/conf/ac/g1_dpspd.yaml @@ -0,0 +1,18 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_accumulation_steps: 1 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/conf/ac/g1_dpspd_fp16.yaml b/conf/ac/g1_dpspd_fp16.yaml new file mode 100644 index 0000000000000000000000000000000000000000..381b21d64e7eeb23734f041ca1ed2c7a9c65dd52 --- /dev/null +++ b/conf/ac/g1_dpspd_fp16.yaml @@ -0,0 +1,18 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_accumulation_steps: 4 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 1 
+rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/conf/cadec.yaml b/conf/cadec.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4186d74890b5c8869c414ea1ebb61a986e40dc30 --- /dev/null +++ b/conf/cadec.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_DiscontinuousNER_CADEC +data_dir: resources/Mirror/new_abilities_v2/cadec/new +best_metric_field: discontinuous_ent.micro.f1 diff --git a/conf/hyperred.yaml b/conf/hyperred.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78e4089063ff593d5ce331d56cea2e99798e781d --- /dev/null +++ b/conf/hyperred.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_HyperRel_HyperRED +data_dir: resources/Mirror/new_abilities_v2/HyperRED/new +best_metric_field: hyper_rel.micro.f1 diff --git a/conf/merge_all_data.yaml b/conf/merge_all_data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45ec8711892db5a9b966d99f39cdafefbb59f4c7 --- /dev/null +++ b/conf/merge_all_data.yaml @@ -0,0 +1,6 @@ +task_name: InstructBert_MergedAllData +data_dir: resources/Mirror/v1.3/merged_pretrained_data +train_filepath: ${data_dir}/train.jsonl +dev_filepath: resources/Mirror/v1.3/uie_data/dev.jsonl +test_filepath: resources/Mirror/v1.3/uie_data/test.jsonl +num_epochs: 1 diff --git a/conf/merge_analysis_data.yaml b/conf/merge_analysis_data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eff69048bafd4a7039c7f53995f404ac12741396 --- /dev/null +++ b/conf/merge_analysis_data.yaml @@ -0,0 +1,18 @@ +task_name: Mirror_MultiTask_Analysis +plm_dir: microsoft/deberta-v3-large + +data_dir: resources/Mirror/uie/merged_analysis +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +num_epochs: 20 +epoch_patience: 3 +regenerate_cache: true + +eval_on_data: [dev] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: general_spans.micro.f1 +final_eval_on_test: true + +base_model_path: null diff --git a/conf/merge_analysis_data_woInstruction.yaml b/conf/merge_analysis_data_woInstruction.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5d3d7dad622641d49ff9ddd53b452324ac280f2 --- /dev/null +++ b/conf/merge_analysis_data_woInstruction.yaml @@ -0,0 +1,18 @@ +task_name: Mirror_MultiTask_Analysis_woInstruction +plm_dir: microsoft/deberta-v3-large + +data_dir: resources/Mirror/uie/merged_analysis/remove_instruction +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +num_epochs: 20 +epoch_patience: 3 +regenerate_cache: true + +eval_on_data: [dev] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: general_spans.micro.f1 +final_eval_on_test: true + +base_model_path: null diff --git a/conf/merge_uie_data.yaml b/conf/merge_uie_data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2450e266315b54a07a5a772252acdf27fcb3953 --- /dev/null +++ b/conf/merge_uie_data.yaml @@ -0,0 +1,18 @@ +task_name: Mirror_woPT_NewMergedUIEData_woOverlap +plm_dir: microsoft/deberta-v3-large + +data_dir: resources/Mirror/uie/merged +train_filepath: ${data_dir}/train_wo_overlap.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +num_epochs: 20 +epoch_patience: 3 +regenerate_cache: true + +eval_on_data: [dev] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: general_spans.micro.f1 
+final_eval_on_test: true + +base_model_path: null diff --git a/conf/mirror-ace05en.yaml b/conf/mirror-ace05en.yaml new file mode 100644 index 0000000000000000000000000000000000000000..106103f6b9aef1747af2d843b98e500738bd8899 --- /dev/null +++ b/conf/mirror-ace05en.yaml @@ -0,0 +1,70 @@ +# task +task_type: SchemaGuidedInstructBertTask +task_name: InstructBert_TagSpan_DebertaV3Base_ACE05ENPlus +comment: '~~content as label, (start, end + 1) span' + +# data preprocessing +max_seq_len: 512 +debug_mode: false +label_span: tag # tag `[LM]` or content `person` +mode: span # w2 (1,2,3) or span (1,3) + +# filepaths +plm_dir: microsoft/deberta-v3-base +# plm_dir: bert-base-cased +# data_dir: resources/Mirror/Tasks/EE/ACE05-EN +# data_dir: resources/Mirror/Tasks/RE/merged-20230502-2340-v1 +# data_dir: resources/Mirror/Tasks/RE/merged-20230502-2358-v2-woADE +# data_dir: resources/Mirror/Tasks/EE/ACE05-EN-labelmap +data_dir: resources/Mirror/v1.3/event/en/ACE05-EN-plus/fixed_instructed +output_dir: outputs +task_dir: ${output_dir}/${task_name} +# train_filepath: ${data_dir}/ACE2005_plus_train.jsonl +# dev_filepath: ${data_dir}/ACE2005_plus_dev.jsonl +# test_filepath: ${data_dir}/ACE2005_plus_test.jsonl +# train_filepath: ${data_dir}/ACE2005_oneie_NER_train.jsonl +# dev_filepath: ${data_dir}/ACE2005_oneie_NER_dev.jsonl +# test_filepath: ${data_dir}/ACE2005_oneie_NER_test.jsonl +# train_filepath: ${data_dir}/ACE2005_oneie_RE_train.jsonl +# dev_filepath: ${data_dir}/ACE2005_oneie_RE_dev.jsonl +# test_filepath: ${data_dir}/ACE2005_oneie_RE_test.jsonl +# train_filepath: ${data_dir}/ACE2005_oneie_EE_train.jsonl +# dev_filepath: ${data_dir}/ACE2005_oneie_EE_dev.jsonl +# test_filepath: ${data_dir}/ACE2005_oneie_EE_test.jsonl +# train_filepath: ${data_dir}/ACE2005_oneie_train.jsonl +# dev_filepath: ${data_dir}/ACE2005_oneie_dev.jsonl +# test_filepath: ${data_dir}/ACE2005_oneie_test.jsonl +# train_filepath: ${data_dir}/train.jsonl +# dev_filepath: ${data_dir}/dev.jsonl +# test_filepath: ${data_dir}/test.jsonl +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl + +dump_cache_dir: ${task_dir}/cache +regenerate_cache: false + +# training +random_seed: 1227 +eval_on_data: [dev, test] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: general_spans.micro.f1 +final_eval_on_test: true +save_every_ckpt: false +save_best_ckpt: true + +warmup_proportion: 0.1 +num_epochs: 50 +epoch_patience: 5 +train_batch_size: 32 +eval_batch_size: 32 +learning_rate: !!float 3e-5 +other_learning_rate: !!float 3e-5 +max_grad_norm: 1.0 +weight_decay: 0.1 + +# model +dropout: 0.3 +use_rope: true +biaffine_size: 512 diff --git a/conf/mirror-multi-task-pretrain.yaml b/conf/mirror-multi-task-pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22b0d039890e82c093114876038e1664ca798b5c --- /dev/null +++ b/conf/mirror-multi-task-pretrain.yaml @@ -0,0 +1,51 @@ +# task +task_type: SchemaGuidedInstructBertTask +task_name: MirrorLarge_SamplingPretrain_woLowResource_woOverlap +comment: '~~content as label, (start, end + 1) span' + +# data preprocessing +max_seq_len: 512 +debug_mode: false +label_span: tag # tag `[LM]` or content `person` +mode: span # w2 (1,2,3) or span (1,3) +stream_mode: false + +# filepaths +plm_dir: microsoft/deberta-v3-large +data_dir: resources/Mirror/v1.4_sampled_v3/merged/woLowResource +output_dir: mirror_outputs +task_dir: ${output_dir}/${task_name} +train_filepath: ${data_dir}/train_wo_overlap.jsonl 
+dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl +dump_cache_dir: ${task_dir}/cache +regenerate_cache: false + +# training +random_seed: 1227 +base_model_path: null +eval_on_data: [train] +select_best_on_data: train +select_best_by_key: loss +final_eval_on_test: false +save_every_ckpt: true +save_best_ckpt: true + +warmup_proportion: 0.1 +num_epochs: 1 +epoch_patience: -1 +num_steps: -1 +step_patience: -1 +step_eval_interval: 3000 +train_batch_size: 8 +eval_batch_size: 8 +grad_accum_steps: 1 +learning_rate: !!float 2e-5 +other_learning_rate: !!float 1e-4 +max_grad_norm: 1.0 +weight_decay: 0.1 + +# model +dropout: 0.3 +use_rope: true +biaffine_size: 512 diff --git a/conf/mrc.yaml b/conf/mrc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a967f09a984c9b51f07fb41873d7d586bee337d --- /dev/null +++ b/conf/mrc.yaml @@ -0,0 +1,43 @@ +# task +task_type: MrcQaTask +task_name: Mirror_RobertaBaseWwm_Cons_MsraMrc +comment: 'GlobalPointer with RoPE' + +# data preprocessing +max_seq_len: 512 +debug_mode: false +mode: cons + +# filepaths +plm_dir: hfl/chinese-roberta-wwm-ext +data_dir: resources/NER/msra/mrc +output_dir: outputs +task_dir: ${output_dir}/${task_name} +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/test.jsonl +test_filepath: ${data_dir}/test.jsonl +dump_cache_dir: ${task_dir}/cache +regenerate_cache: true + +# training +random_seed: 1227 +eval_on_data: [dev] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: micro.f1 +final_eval_on_test: true + +warmup_proportion: 0.1 +step_eval_interval: 20000 +step_patience: -1 +num_epochs: 5 +epoch_patience: 5 +train_batch_size: 32 +eval_batch_size: 64 +learning_rate: !!float 5e-5 +other_learning_rate: !!float 1e-4 +max_grad_norm: 1.0 + +# model +dropout: 0.3 +biaffine_size: 512 diff --git a/conf/ner.yaml b/conf/ner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..349156646000fe19e98aea7fb005fd40a48514a3 --- /dev/null +++ b/conf/ner.yaml @@ -0,0 +1,45 @@ +# task +task_type: MrcTaggingTask +task_name: debug-Mirror_W2_MSRAv2_NER_FreezeBertEmbAnd0-3_bs64 +comment: 'bert mrc w/ w2ner for NER' + +# data preprocessing +max_seq_len: 300 +negative_sample_prob: 1.0 +debug_mode: false +mode: w2 + +# filepaths +base_model_path: outputs/RobertaBase_data20230314v2/ckpt/MrcGlobalPointerModel.best.pth +plm_dir: hfl/chinese-roberta-wwm-ext +data_dir: resources/NER/MSRA_v2/formatted +output_dir: outputs +task_dir: ${output_dir}/${task_name} +train_filepath: ${data_dir}/train.char.bmes.jsonl +dev_filepath: ${data_dir}/dev.char.bmes.jsonl +test_filepath: ${data_dir}/test.char.bmes.jsonl +ent_type2query_filepath: ${data_dir}/query.json +dump_cache_dir: ${task_dir}/cache +regenerate_cache: true + +# training +random_seed: 1227 +eval_on_data: [dev, test] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: micro.f1 +final_eval_on_test: true + +warmup_proportion: 0.1 +num_epochs: 5 +epoch_patience: 5 +train_batch_size: 64 +eval_batch_size: 128 +learning_rate: !!float 5e-5 +other_learning_rate: !!float 1e-4 +max_grad_norm: 1.0 +weight_decay: 0.1 + +# model +dropout: 0.3 +biaffine_size: 512 diff --git a/conf/nlu/cola.yaml b/conf/nlu/cola.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7dfcdc810cdf630cabfddcc2e50a0095adf4882 --- /dev/null +++ b/conf/nlu/cola.yaml @@ -0,0 +1,6 @@ +task_name: Mirror_SingleTask_Cls_CoLA +data_dir: resources/Mirror/v1.3/cls/en/CoLA/formated +train_filepath: ${data_dir}/train.jsonl 
+dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/dev.jsonl +best_metric_field: cls.mcc diff --git a/conf/nlu/mnli.yaml b/conf/nlu/mnli.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7ddb6a927dcd8510fceb376d9cb178a72b2838a --- /dev/null +++ b/conf/nlu/mnli.yaml @@ -0,0 +1,6 @@ +task_name: Mirror_SingleTask_Cls_MNLI +data_dir: resources/Mirror/v1.3/cls/en/MNLI/formated +train_filepath: ${data_dir}/MNLI_train.jsonl +dev_filepath: ${data_dir}/MNLI_dev.jsonl +test_filepath: ${data_dir}/MNLI_dev.jsonl +best_metric_field: cls.acc diff --git a/conf/nlu/mrpc.yaml b/conf/nlu/mrpc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc9841b3c6fbd3b36ddcd50dcb81a8958660439f --- /dev/null +++ b/conf/nlu/mrpc.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Cls_MRPC +data_dir: resources/Mirror/v1.3/cls/en/MRPC/formated +best_metric_field: cls.acc diff --git a/conf/nlu/plm.yaml b/conf/nlu/plm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..770e53393bc86b8eb4a002424e7d91f7881bef51 --- /dev/null +++ b/conf/nlu/plm.yaml @@ -0,0 +1,19 @@ +plm_dir: microsoft/deberta-v3-large +base_model_path: mirror_outputs/Mirror_Pretrain_AllExcluded_2/ckpt/SchemaGuidedInstructBertModel.best.pth + +stream_mode: false +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl + +num_epochs: 5 +epoch_patience: -1 +num_steps: -1 +step_patience: -1 +step_eval_interval: -1 + +eval_on_data: [dev] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: general_spans.micro.f1 +final_eval_on_test: true diff --git a/conf/nlu/qnli.yaml b/conf/nlu/qnli.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b95f03aba58455d2a21585cc7f63751e7152694d --- /dev/null +++ b/conf/nlu/qnli.yaml @@ -0,0 +1,6 @@ +task_name: Mirror_SingleTask_Cls_QNLI +data_dir: resources/Mirror/v1.3/cls/en/QNLI/processed +train_filepath: ${data_dir}/QNLI_train.jsonl +dev_filepath: ${data_dir}/QNLI_dev.jsonl +test_filepath: ${data_dir}/QNLI_dev.jsonl +best_metric_field: cls.acc diff --git a/conf/nlu/qqp.yaml b/conf/nlu/qqp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62e64176c9e75536beee7adc1f912e8694e5f17a --- /dev/null +++ b/conf/nlu/qqp.yaml @@ -0,0 +1,6 @@ +task_name: Mirror_SingleTask_Cls_QQP +data_dir: resources/Mirror/v1.3/cls/en/QQP/new +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/dev.jsonl +best_metric_field: cls.acc diff --git a/conf/nlu/rte.yaml b/conf/nlu/rte.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c281cab7a5f8c5f0f816cd809066ab68f883c126 --- /dev/null +++ b/conf/nlu/rte.yaml @@ -0,0 +1,6 @@ +task_name: Mirror_SingleTask_Cls_RTE +data_dir: resources/Mirror/v1.3/cls/en/RTE/formated +train_filepath: ${data_dir}/RTE_train.jsonl +dev_filepath: ${data_dir}/RTE_dev.jsonl +test_filepath: ${data_dir}/RTE_dev.jsonl +best_metric_field: cls.acc diff --git a/conf/nlu/squad_v2.yaml b/conf/nlu/squad_v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8fada03a197b25ee0976cfb29b041cabd8917513 --- /dev/null +++ b/conf/nlu/squad_v2.yaml @@ -0,0 +1,4 @@ +task_name: Mirror_SingleTask_MRC_SQuADv2 +data_dir: resources/Mirror/v1.3/span/en/squad_v2 +test_filepath: ${data_dir}/dev.jsonl +best_metric_field: span.f1.f1 diff --git a/conf/nlu/sst-2.yaml b/conf/nlu/sst-2.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..873b510f4b2c90935643eda318d3a1bc12fba37e --- /dev/null +++ b/conf/nlu/sst-2.yaml @@ -0,0 +1,6 @@ +task_name: Mirror_SingleTask_Cls_SST2 +data_dir: resources/Mirror/v1.3/cls/en/SST-2/instructed +train_filepath: ${data_dir}/SST-2_train.jsonl +dev_filepath: ${data_dir}/SST-2_dev.jsonl +test_filepath: ${data_dir}/SST-2_dev.jsonl +best_metric_field: cls.acc diff --git a/conf/t-rex_pretrain.yaml b/conf/t-rex_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ba2f299aea8c43218609a8463afe14d7241c8e9 --- /dev/null +++ b/conf/t-rex_pretrain.yaml @@ -0,0 +1,9 @@ +task_name: InstructBert_TagSpan_DebertaV3Base_TRExPretrain +data_dir: resources/Mirror/v1.3/rel/en/T-REx/instructed +train_filepath: ${data_dir}/t-rex.udi.fix.jsonl + +num_epochs: 3 +eval_on_data: [train] +select_best_on_data: train +select_best_by_key: loss +final_eval_on_test: false diff --git a/conf/uie_data/absa_14lap.yaml b/conf/uie_data/absa_14lap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee860dcb9141cd9819cca6d26a60f845fa2db2f6 --- /dev/null +++ b/conf/uie_data/absa_14lap.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_ABSA_14lap +data_dir: resources/Mirror/uie/absa/14lap +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/absa_14res.yaml b/conf/uie_data/absa_14res.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3447a0bc10d81d54389b4478499afe480b68273 --- /dev/null +++ b/conf/uie_data/absa_14res.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_ABSA_14res +data_dir: resources/Mirror/uie/absa/14res +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/absa_15res.yaml b/conf/uie_data/absa_15res.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d70f989e346d14ded78e5aa99f4fd17411236bea --- /dev/null +++ b/conf/uie_data/absa_15res.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_ABSA_15res +data_dir: resources/Mirror/uie/absa/15res +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/absa_16res.yaml b/conf/uie_data/absa_16res.yaml new file mode 100644 index 0000000000000000000000000000000000000000..10c1a8a073e2e26a02dfb626823bf5920867e670 --- /dev/null +++ b/conf/uie_data/absa_16res.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_ABSA_16res +data_dir: resources/Mirror/uie/absa/16res +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/ent_ace04.yaml b/conf/uie_data/ent_ace04.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49969a4a5d5c69150453f7e9ee9cc2e9004f9cdb --- /dev/null +++ b/conf/uie_data/ent_ace04.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Ent_ACE04 +data_dir: resources/Mirror/uie/ent/ace04 +best_metric_field: ent.micro.f1 diff --git a/conf/uie_data/ent_ace05.yaml b/conf/uie_data/ent_ace05.yaml new file mode 100644 index 0000000000000000000000000000000000000000..413e9082313f503b991e51ce9dbe6c022f4a4a83 --- /dev/null +++ b/conf/uie_data/ent_ace05.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Ent_ACE05 +data_dir: resources/Mirror/uie/ent/ace05 +best_metric_field: ent.micro.f1 diff --git a/conf/uie_data/ent_conll03.yaml b/conf/uie_data/ent_conll03.yaml new file mode 100644 index 0000000000000000000000000000000000000000..377ceb61de89ed29e9841f94f9a1bfea536d40a4 --- /dev/null +++ b/conf/uie_data/ent_conll03.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Ent_CoNLL03 +data_dir: resources/Mirror/uie/ent/conll03 +best_metric_field: ent.micro.f1 diff --git a/conf/uie_data/event_ace05.yaml 
b/conf/uie_data/event_ace05.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f31ef7454b35217b4142ebc2de7dfed0565eb69e --- /dev/null +++ b/conf/uie_data/event_ace05.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Event_ACE05 +data_dir: resources/Mirror/uie/event/ace05-evt +best_metric_field: event.arg_cls.f1 diff --git a/conf/uie_data/event_casie.yaml b/conf/uie_data/event_casie.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79de43b538f11fbe6459723abbd9333912c85e39 --- /dev/null +++ b/conf/uie_data/event_casie.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Event_CASIE +data_dir: resources/Mirror/uie/event/casie +best_metric_field: event.arg_cls.f1 diff --git a/conf/uie_data/fewshot.yaml b/conf/uie_data/fewshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e20665083de1cd3a17b7fa83ff665e58aeea0eb9 --- /dev/null +++ b/conf/uie_data/fewshot.yaml @@ -0,0 +1,5 @@ +num_epochs: 200 +epoch_patience: 10 +output_dir: mirror_fewshot_outputs +base_model_path: mirror_outputs/Mirror_Pretrain_AllExcluded_2/ckpt/SchemaGuidedInstructBertModel.best.pth +save_every_ckpt: false diff --git a/conf/uie_data/merged.yaml b/conf/uie_data/merged.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c88730b7f7ff6384b06a1e871df7ba490dccc93 --- /dev/null +++ b/conf/uie_data/merged.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_MultiTask_UIE +data_dir: resources/Mirror/uie/merged +best_metric_field: general_spans.micro.f1 diff --git a/conf/uie_data/rel_ace05.yaml b/conf/uie_data/rel_ace05.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c8951eb48c2e0e6476a711ad4d5445a6c68597f --- /dev/null +++ b/conf/uie_data/rel_ace05.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Rel_ACE05 +data_dir: resources/Mirror/uie/rel/ace05-rel +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/rel_conll04.yaml b/conf/uie_data/rel_conll04.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3bde2d363451745a3fff17bb0779a69b888e5ab --- /dev/null +++ b/conf/uie_data/rel_conll04.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Rel_CoNLL04 +data_dir: resources/Mirror/uie/rel/conll04 +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/rel_nyt.yaml b/conf/uie_data/rel_nyt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1131946b1bce7022603c684084531e9e1c80eb9d --- /dev/null +++ b/conf/uie_data/rel_nyt.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Rel_NYT +data_dir: resources/Mirror/uie/rel/nyt +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/rel_scierc.yaml b/conf/uie_data/rel_scierc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7655405e61cae5f07ce9261c74839f371ccb6428 --- /dev/null +++ b/conf/uie_data/rel_scierc.yaml @@ -0,0 +1,3 @@ +task_name: Mirror_SingleTask_Rel_SciERC +data_dir: resources/Mirror/uie/rel/scierc +best_metric_field: rel.rel.micro.f1 diff --git a/conf/uie_data/wPretrain.yaml b/conf/uie_data/wPretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a54be67d2b04181502a91dd25dca1a1bfa8daedc --- /dev/null +++ b/conf/uie_data/wPretrain.yaml @@ -0,0 +1,19 @@ +plm_dir: microsoft/deberta-v3-large +base_model_path: mirror_outputs/Mirror_Pretrain_AllExcluded_2/ckpt/SchemaGuidedInstructBertModel.best.pth + +stream_mode: false +train_filepath: ${data_dir}/train.jsonl +dev_filepath: ${data_dir}/dev.jsonl +test_filepath: ${data_dir}/test.jsonl + +num_epochs: 20 
+epoch_patience: 3 +num_steps: -1 +step_patience: -1 +step_eval_interval: -1 + +eval_on_data: [dev] +select_best_on_data: dev +select_best_by_key: metric +best_metric_field: general_spans.micro.f1 +final_eval_on_test: true diff --git a/eval.py b/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/index.html b/index.html new file mode 100644 index 0000000000000000000000000000000000000000..1af05bcbaf3e3599411534ddf8f35633b50a306e --- /dev/null +++ b/index.html @@ -0,0 +1,288 @@ + + + + + + + + ๐ŸชžMirror + + + + + + +
+

๐ŸชžMirror

+

+ ๐ŸชžMirror can help you deal with a wide range of Natural Language Understanding and Information Extraction tasks. +

+
+ +
+
+
+ + +
+
+ +

Split with # for multiple inputs

+

For entities, relations or classification, input {"ent|rel|cls": ["cls1", "type2"]} .

+

For events and hyper relations, input {"type": ["role1", "role2"]} .

+ + +
+
+ + +
+ + +
+ + + +
+ +
+

โฑ๏ธ {{ searchSecondsString }}

+
+ +
+
+ + + + + + + + + + + +
Item Predicted
+
+
+ +
+
+ + + + + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8089003957db8735e2913a88283c46afc480d77f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +pandas +rich +numpy +omegaconf +gpu-watchmen +tqdm +datasets +transformers +gradio +git+https://github.com/Spico197/REx diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/analyze.py b/src/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..d2f25c0c9a52e34d5c331ccfa98e11abbc13e55e --- /dev/null +++ b/src/analyze.py @@ -0,0 +1,135 @@ +from collections import defaultdict + +from rex.metrics.tagging import tagging_prf1 +from rex.utils.io import load_jsonlines +from rex.utils.position import find_all_positions + + +def main(): + middle_filepath = "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_labelmap_Rel_updateTag_bs32/middle/test.final.jsonl" + data = load_jsonlines(middle_filepath) + for ins in data: + gold = ins["gold"] + pred = ins["pred"] + if gold["spans"] != pred["spans"]: + breakpoint() + + +def check_ent_string_matching_upper_bound(filepath: str, strategy: str = "first"): + def _check_overlap(x, y): + if x[0] > y[1] or y[0] > x[1]: + return False + else: + return True + + data = load_jsonlines(filepath) + golds = [] + preds = [] + for ins in data: + text = ins["text"] + gold_ents = ins["ans"]["ent"] + gold_ents = list( + set([(ent["text"], ent["type"], tuple(ent["span"])) for ent in gold_ents]) + ) + gold_ents.sort(key=lambda x: len(x[0]), reverse=True) + pred_ents = [] + matched = set() + for gold_ent in gold_ents: + ent_string = gold_ent[0] + ent_type = gold_ent[1] + positions = find_all_positions(text, ent_string) + if strategy == "first": + for position in positions: + if (ent_type, position) not in matched: + matched.add((ent_type, position)) + pred_ents.append((ent_string, ent_type, tuple(position))) + else: + flag = False + for position in positions: + for _, g in matched: + if _check_overlap(g, position): + flag = True + if flag: + continue + + if (ent_type, position) not in matched: + matched.add((ent_type, position)) + pred_ents.append((ent_string, ent_type, tuple(position))) + break + + golds.append(gold_ents) + preds.append(pred_ents) + + results = tagging_prf1(golds, preds) + + print(f"filepath: {filepath}, Strategy: {strategy}") + print(f"Results: {results['micro']}") + + +def check_rel_tanl_upper_bound(filepath): + data = load_jsonlines(filepath) + golds = [] + preds = [] + for ins in data: + text = ins["text"] + gold_rels = ins["ans"]["rel"] + ent_text_to_spans = defaultdict(set) + for ent in ins["ans"]["ent"]: + ent_text_to_spans[ent["text"]].add(tuple(ent["span"])) + gold_rels = list( + set( + [ + ( + tuple(rel["head"]["span"]), + rel["relation"], + tuple(rel["tail"]["span"]), + ) + for rel in gold_rels + ] + ) + ) + pred_rels = [] + for pred_rel in ins["ans"]["rel"]: + # pred_triple = () + tail_text = pred_rel["tail"]["text"] + if ( + tail_text in ent_text_to_spans + and len(ent_text_to_spans[tail_text]) == 1 + ): + tail_span = list(ent_text_to_spans[tail_text])[0] + pred_rels.append( + (tuple(pred_rel["head"]["span"]), pred_rel["relation"], tail_span) + ) + # if tail_text in ent_text_to_spans: + # tail_span = list(ent_text_to_spans[tail_text])[0] + # else: + # tail_span = find_all_positions(text, tail_text)[0] + # pred_rels.append((tuple(pred_rel["head"]["span"]), pred_rel["relation"], 
tail_span)) + + golds.append(gold_rels) + preds.append(pred_rels) + + results = tagging_prf1(golds, preds) + + print(f"filepath: {filepath}") + print(f"Results: {results['micro']}") + + +if __name__ == "__main__": + # main() + + # for filepath in [ + # "/data/tzhu/Mirror/resources/Mirror/uie/ent/ace04/test.jsonl", + # "/data/tzhu/Mirror/resources/Mirror/uie/ent/ace05/test.jsonl", + # "/data/tzhu/Mirror/resources/Mirror/uie/ent/conll03/test.jsonl", + # ]: + # for strategy in ["first", "longer_first"]: + # check_ent_string_matching_upper_bound(filepath, strategy) + + for filepath in [ + "/data/tzhu/Mirror/resources/Mirror/uie/rel/ace05-rel/test.jsonl", + "/data/tzhu/Mirror/resources/Mirror/uie/rel/conll04/test.jsonl", + "/data/tzhu/Mirror/resources/Mirror/uie/rel/nyt/test.jsonl", + "/data/tzhu/Mirror/resources/Mirror/uie/rel/scierc/test.jsonl", + ]: + check_rel_tanl_upper_bound(filepath) diff --git a/src/app/__init__.py b/src/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/app/api_backend.py b/src/app/api_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..220f6484880bfc78b2052a5c1d72a11e5ecfe38f --- /dev/null +++ b/src/app/api_backend.py @@ -0,0 +1,80 @@ +import traceback +from typing import Any, Dict, List + +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel +from rex.utils.initialization import set_seed_and_log_path + +from src.task import SchemaGuidedInstructBertTask + +set_seed_and_log_path(log_path="debug.log") + +app = FastAPI() + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], +) + + +class RequestData(BaseModel): + data: List[Dict[str, Any]] + + +task = SchemaGuidedInstructBertTask.from_taskdir( + "mirror_outputs/Mirror_Pretrain_AllExcluded_2", + load_best_model=True, + initialize=False, + dump_configfile=False, + update_config={ + "regenerate_cache": False, + }, +) + + +@app.post("/process") +def process_data(data: RequestData): + input_data = data.data + + ok = True + msg = "" + results = {} + try: + results = task.predict(input_data) + msg = "success" + except KeyboardInterrupt: + raise KeyboardInterrupt + except Exception: + ok = False + msg = traceback.format_exc() + + # Return the processed data + return {"ok": ok, "msg": msg, "results": results} + + +@app.get("/") +async def api(): + return FileResponse("./index.html", media_type="text/html") + + +if __name__ == "__main__": + log_config = uvicorn.config.LOGGING_CONFIG + log_config["formatters"]["access"]["fmt"] = ( + "%(asctime)s | " + log_config["formatters"]["access"]["fmt"] + ) + log_config["formatters"]["default"]["fmt"] = ( + "%(asctime)s | " + log_config["formatters"]["default"]["fmt"] + ) + uvicorn.run( + "src.app.api_backend:app", + host="0.0.0.0", + port=7860, + log_level="debug", + log_config=log_config, + reload=True, + ) diff --git a/src/app/demo1_deprecated.py b/src/app/demo1_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..e41018bfdb17651c97e2a03660044e67d629ebfc --- /dev/null +++ b/src/app/demo1_deprecated.py @@ -0,0 +1,97 @@ +import gradio as gr +from rex.utils.initialization import set_seed_and_log_path +from rex.utils.logging import logger + +from src.task import MrcQaTask, SchemaGuidedInstructBertTask + +set_seed_and_log_path(log_path="app.log") + + +class MrcQaPipeline: + 
    def __init__(self, task_dir: str, load_path: str = None) -> None:
+        self.task = MrcQaTask.from_taskdir(
+            task_dir, load_best_model=load_path is None, initialize=False
+        )
+        if load_path:
+            self.task.load(load_path, load_history=False)
+
+    def predict(self, query, context, background=None):
+        data = [
+            {
+                "query": query,
+                "context": context,
+                "background": background,
+            }
+        ]
+        results = self.task.predict(data)
+        ret = results[0]
+
+        data[0]["pred"] = ret
+        logger.opt(colors=False).debug(data[0])
+
+        return ret
+
+
+class InstructBertPipeline:
+    def __init__(self, task_dir: str, load_path: str = None) -> None:
+        self.task = SchemaGuidedInstructBertTask.from_taskdir(
+            task_dir, load_best_model=load_path is None, initialize=False
+        )
+        if load_path:
+            self.task.load(load_path, load_history=False)
+
+    def predict(self, instruction, schema, text, background):
+        # build a UDI-style instance; keys follow the format expected by
+        # SchemaGuidedInstructBertTask.predict_api
+        data = [
+            {
+                "id": "demo1",
+                "instruction": instruction,
+                "schema": schema,
+                "text": text,
+                "bg": background,
+                "ans": {},
+            }
+        ]
+        results = self.task.predict(data)
+        ret = results[0]
+
+        data[0]["pred"] = ret
+        logger.opt(colors=False).debug(data[0])
+
+        return ret
+
+
+def mrc_qa():
+    pipe = MrcQaPipeline("outputs/RobertaBase_data20230314v2")
+
+    with gr.Blocks() as demo:
+        gr.Markdown("# 🪞 Mirror Mirror")
+
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    query = gr.Textbox(
+                        label="Query", placeholder="Mirror Mirror, tell me ..."
+                    )
+                with gr.Row():
+                    context = gr.TextArea(
+                        label="Candidates",
+                        placeholder="Separated by comma (,) without spaces.",
+                    )
+                with gr.Row():
+                    background = gr.TextArea(
+                        label="Background",
+                        placeholder="Background explanation, could be empty",
+                    )
+
+            with gr.Column():
+                with gr.Row():
+                    trigger_button = gr.Button("Tell me the truth", variant="primary")
+                with gr.Row():
+                    output = gr.TextArea(label="Output")
+
+        trigger_button.click(
+            pipe.predict, inputs=[query, context, background], outputs=output
+        )
+
+    demo.launch(show_error=True, share=False)
+
+
+def instruct_bert_pipeline():
+    task = SchemaGuidedInstructBertTask.from_taskdir()
diff --git a/src/app/gradio_app.py b/src/app/gradio_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aa58aa0b48245f4ef5bedaf45e89c99023e3f15
--- /dev/null
+++ b/src/app/gradio_app.py
@@ -0,0 +1,58 @@
+import json
+
+import gradio as gr
+from rex.utils.initialization import set_seed_and_log_path
+
+from src.task import SchemaGuidedInstructBertTask
+
+set_seed_and_log_path(log_path="debug.log")
+
+
+task = SchemaGuidedInstructBertTask.from_taskdir(
+    "mirror_outputs/Mirror_Pretrain_AllExcluded_2",
+    load_best_model=True,
+    initialize=False,
+    dump_configfile=False,
+    update_config={
+        "regenerate_cache": False,
+    },
+)
+
+
+def ask_mirror(instruction, schema, text):
+    input_data = {
+        "id": "app",
+        "instruction": instruction,
+        "schema": json.loads(schema),
+        "text": text,
+        "ans": {},
+    }
+    # predict expects a list of instances
+    results = task.predict([input_data])
+    return results
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("# 🪞Mirror")
+    gr.Markdown(
+        "🪞Mirror can help you deal with a wide range of Natural Language Understanding and Information Extraction tasks."
+ ) + gr.Markdown( + "[[paper]](https://arxiv.org/abs/2311.05419) | [[code]](https://github.com/Spico197/Mirror)" + ) + + instruction = gr.Textbox(label="Instruction") + schema = gr.Textbox( + label="schema", + placeholder='{"cls": ["class1", "class2"], "ent": ["type1", "type2"], "rel": ["relation1", "relation2"]} leave it as {} to support span extraction.', + ) + text = gr.TextArea(label="Text") + output = gr.Textbox(label="Output") + + submit_btn = gr.Button("Ask Mirror") + submit_btn.click(ask_mirror, inputs=[instruction, schema, text], outputs=output) + + gr.Markdown("Made by Mirror Team w/ ๐Ÿ’–") + + +if __name__ == "__main__": + demo.launch() diff --git a/src/eval.py b/src/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..f5158f3d892c73e0a2d15042b12e11904d7f7ac9 --- /dev/null +++ b/src/eval.py @@ -0,0 +1,142 @@ +from pathlib import Path + +import pandas as pd +from rex.utils.initialization import set_seed_and_log_path +from rex.utils.io import load_json +from rich.console import Console +from rich.table import Table + +from src.task import SchemaGuidedInstructBertTask + +set_seed_and_log_path(log_path="tmp_eval.log") + + +if __name__ == "__main__": + task_dir = "mirror_outputs/Mirror_Pretrain_AllExcluded_2" + # task_dir = "mirror_outputs/Mirror_SingleTask_wPTAllExcluded_Event_ACE05" + task: SchemaGuidedInstructBertTask = SchemaGuidedInstructBertTask.from_taskdir( + task_dir, + load_best_model=True, + initialize=False, + dump_configfile=False, + update_config={ + "regenerate_cache": True, + "eval_on_data": ["dev"], + "select_best_on_data": "dev", + "select_best_by_key": "metric", + "best_metric_field": "general_spans.micro.f1", + "eval_batch_size": 32, + }, + ) + table = Table(title=task_dir) + + data_pairs = [ + # fmt: off + + # UIE eval data + # ["ent_ace04_test", "resources/Mirror/uie/ent/ace04/test.jsonl"], + # ["ent_ace05_test", "resources/Mirror/uie/ent/ace05/test.jsonl"], + ["ent_conll03_test", "resources/Mirror/uie/ent/conll03/test.jsonl"], + # ["rel_ace05_test", "resources/Mirror/uie/rel/ace05-rel/test.jsonl"], + ["rel_conll04_test", "resources/Mirror/uie/rel/conll04/test.jsonl"], + # ["rel_nyt_test", "resources/Mirror/uie/rel/nyt/test.jsonl"], + # ["rel_scierc_test", "resources/Mirror/uie/rel/scierc/test.jsonl"], + ["event_ace05_test", "resources/Mirror/uie/event/ace05-evt/test.jsonl"], + # ["event_casie_test", "resources/Mirror/uie/event/casie/test.jsonl"], + # ["absa_14res_test", "resources/Mirror/uie/absa/14res/test.jsonl"], + # ["absa_14lap_test", "resources/Mirror/uie/absa/14lap/test.jsonl"], + # ["absa_15res_test", "resources/Mirror/uie/absa/15res/test.jsonl"], + # ["absa_16res_test", "resources/Mirror/uie/absa/16res/test.jsonl"], + # # discontinuous NER + # ["discontinuous_ent", "resources/Mirror/new_abilities_v2/cadec/new/test.jsonl"], + # # hyper-RE + # ["hyper_rel", "resources/Mirror/new_abilities_v2/HyperRED/new/test.jsonl"], + # # zero-shot NER + # ["ent_movie", "resources/Mirror/v1.3/ent/en/MIT_MOVIE_Review/instructed/test.jsonl"], + # ["ent_restaurant", "resources/Mirror/v1.3/ent/en/MIT_Restaurant_Review/instructed/test.jsonl"], + # ["ent_ai", "resources/Mirror/v1.3/ent/en/CrossNER_AI/instructed/test.jsonl"], + # ["ent_literature", "resources/Mirror/v1.3/ent/en/CrossNER_literature/instructed/test.jsonl"], + # ["ent_music", "resources/Mirror/v1.3/ent/en/CrossNER_music/instructed/test.jsonl"], + # ["ent_politics", "resources/Mirror/v1.3/ent/en/CrossNER_politics/instructed/test.jsonl"], + # ["ent_science", 
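# Illustrative schema input for the Gradio demo above: the textbox takes a JSON
# object keyed by task type, matching the placeholder in gradio_app.py.
import json

schema_str = '{"cls": [], "ent": ["person", "organization"], "rel": []}'
schema = json.loads(schema_str)  # ask_mirror() parses the textbox the same way
assert schema["ent"] == ["person", "organization"]
# An empty object ("{}") switches the demo to plain span extraction.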
"resources/Mirror/v1.3/ent/en/CrossNER_science/instructed/test.jsonl"], + # # mrc + # ["span_squad2", "resources/Mirror/v1.3/span/en/squad_v2/dev.jsonl"], + # # glue + # ["cls_glue_cola", "resources/Mirror/v1.3/cls/en/CoLA/formated/dev.jsonl"], + # ["cls_glue_qqp", "resources/Mirror/v1.3/cls/en/QQP/new/dev.jsonl"], + # ["cls_glue_mnli", "resources/Mirror/v1.3/cls/en/MNLI/formated/MNLI_dev.jsonl"], + # ["cls_glue_sst2", "resources/Mirror/v1.3/cls/en/SST-2/instructed/SST-2_dev.jsonl"], + # ["cls_glue_qnli", "resources/Mirror/v1.3/cls/en/QNLI/processed/QNLI_dev.jsonl"], + # ["cls_glue_rte", "resources/Mirror/v1.3/cls/en/RTE/formated/RTE_dev.jsonl"], + # ["cls_glue_mrpc", "resources/Mirror/v1.3/cls/en/MRPC/formated/dev.jsonl"], + # fmt: on + ] + + eval_res = {"task": [], "dataset": [], "metric_val": []} + table.add_column("Task", justify="left", style="cyan") + table.add_column("Dataset", justify="left", style="magenta") + table.add_column("Metric (%)", justify="right", style="green") + for dname, fpath in data_pairs: + dname = dname.lower() + task.data_manager.update_datapath(dname, fpath) + _, res = task.eval(dname, verbose=True, dump=True, dump_middle=True) + # res = load_json(Path(task_dir) / "measures" / f"{dname}.json")["metrics"] + if dname.startswith("ent_"): + eval_res["task"].append("ent") + eval_res["dataset"].append(dname) + eval_res["metric_val"].append(res["ent"]["micro"]["f1"]) + elif dname.startswith("rel_"): + eval_res["task"].append("rel") + eval_res["dataset"].append(dname) + eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"]) + elif dname.startswith("event_"): + eval_res["task"].append("event") + eval_res["dataset"].append(dname + "_tgg") + eval_res["metric_val"].append(res["event"]["trigger_cls"]["f1"]) + eval_res["task"].append("event") + eval_res["dataset"].append(dname + "_arg") + eval_res["metric_val"].append(res["event"]["arg_cls"]["f1"]) + elif dname.startswith("absa_"): + eval_res["task"].append("absa") + eval_res["dataset"].append(dname) + eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"]) + elif dname.startswith("cls_"): + eval_res["task"].append("cls") + eval_res["dataset"].append(dname) + if "_glue_" in dname: + if "_cola" in dname: + eval_res["metric_val"].append(res["cls"]["mcc"]) + else: + eval_res["metric_val"].append(res["cls"]["acc"]) + else: + eval_res["metric_val"].append(res["cls"]["mf1"]["micro"]["f1"]) + elif dname.startswith("span"): + eval_res["task"].append("span_em") + eval_res["dataset"].append(dname) + eval_res["metric_val"].append(res["span"]["em"]) + eval_res["task"].append("span_f1") + eval_res["dataset"].append(dname) + eval_res["metric_val"].append(res["span"]["f1"]["f1"]) + elif dname.startswith("discontinuous_ent"): + eval_res["task"].append("discontinuous_ent") + eval_res["dataset"].append(dname) + eval_res["metric_val"].append(res["discontinuous_ent"]["micro"]["f1"]) + elif dname.startswith("hyper_rel"): + eval_res["task"].append("hyper_rel") + eval_res["dataset"].append(dname) + eval_res["metric_val"].append(res["hyper_rel"]["micro"]["f1"]) + else: + raise ValueError + + for i in range(len(eval_res["task"])): + table.add_row( + eval_res["task"][i], + eval_res["dataset"][i], + f"{100*eval_res['metric_val'][i]:.3f}", + ) + + console = Console() + console.print(table) + + df = pd.DataFrame(eval_res) + df.to_excel(task.measures_path.joinpath("data_eval_res.xlsx")) diff --git a/src/get_avg_results.py b/src/get_avg_results.py new file mode 100644 index 
0000000000000000000000000000000000000000..f57d9705424451452bb2b17e3f17cc3b23e1897e
--- /dev/null
+++ b/src/get_avg_results.py
@@ -0,0 +1,89 @@
+import os
+import re
+import statistics as sts
+from collections import defaultdict
+from pathlib import Path
+
+from rex.utils.dict import get_dict_content
+from rex.utils.io import load_json
+from rich.console import Console
+from rich.table import Table
+
+inputs_dir = Path("mirror_fewshot_outputs")
+# regex = re.compile(r"Mirror_SingleTask_(.*?)_seed(\d+)_(\d+)shot")
+regex = re.compile(r"Mirror_wPT_woInst_(.*?)_seed(\d+)_(\d+)shot")
+
+# task -> shot -> seeds
+results = defaultdict(lambda: defaultdict(list))
+
+for dirname in os.listdir(inputs_dir):
+    dpath = inputs_dir / dirname
+    re_matched = regex.match(dirname)
+    if dpath.is_dir() and re_matched:
+        task, seed, shot = re_matched.groups()
+        results_json_p = dpath / "measures" / "test.final.json"
+        metrics = load_json(results_json_p)
+        if "Ent_" in task:
+            results[task][shot].append(
+                get_dict_content(metrics, "metrics.ent.micro.f1")
+            )
+        elif "Rel_" in task or "ABSA_" in task:
+            results[task][shot].append(
+                get_dict_content(metrics, "metrics.rel.rel.micro.f1")
+            )
+        elif "Event_" in task:
+            results[task + "_Trigger"][shot].append(
+                get_dict_content(metrics, "metrics.event.trigger_cls.f1")
+            )
+            results[task + "_Arg"][shot].append(
+                get_dict_content(metrics, "metrics.event.arg_cls.f1")
+            )
+        else:
+            raise RuntimeError
+
+table = Table(title="Few-shot results")
+table.add_column("Task", justify="center")
+table.add_column("1-shot", justify="right")
+table.add_column("5-shot", justify="right")
+table.add_column("10-shot", justify="right")
+table.add_column("Avg.", justify="right")
+for task in results:
+    shots = sorted(results[task].keys(), key=lambda x: int(x))
+    all_seeds = []
+    shot_results = []
+    for shot in shots:
+        seeds = results[task][shot]
+        all_seeds.extend(seeds)
+        avg = sum(seeds) / len(seeds)
+        std = sts.stdev(seeds)
+        shot_results.append(f"{100*avg:.2f}±{100*std:.2f}")
+    shot_results.append(f"{100*sts.mean(all_seeds):.2f}")
+    table.add_row(task, *shot_results)
+
+console = Console()
+console.print(table)
+
+"""
+                  Few-shot results wPT wInst
+┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
+┃ Task                ┃      1-shot ┃      5-shot ┃    10-shot ┃  Avg. ┃
+┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
+│ Ent_CoNLL03         │  77.50±1.64 │  82.73±2.29 │ 84.48±1.62 │ 81.57 │
+│ Rel_CoNLL04         │ 34.66±10.52 │  52.23±3.16 │ 58.68±1.77 │ 48.52 │
+│ Event_ACE05_Trigger │  49.50±3.59 │ 65.61±19.29 │ 60.68±2.45 │ 58.60 │
+│ Event_ACE05_Arg     │  23.46±1.66 │ 48.32±28.91 │ 41.90±1.95 │ 37.89 │
+│ ABSA_16res          │  67.06±0.56 │ 73.51±14.75 │ 68.70±1.46 │ 69.76 │
+└─────────────────────┴─────────────┴─────────────┴────────────┴───────┘
+
+                  Few-shot results wPT woInst
+┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
+┃ Task                ┃      1-shot ┃     5-shot ┃    10-shot ┃  Avg. ┃
+┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
+│ Ent_CoNLL03         │  76.33±1.74 │ 82.50±1.87 │ 84.47±1.18 │ 81.10 │
+│ woInst_Rel_CoNLL04  │  34.86±6.20 │ 48.00±4.44 │ 55.65±2.53 │ 46.17 │
+│ Rel_CoNLL04         │ 26.83±15.22 │ 47.39±3.60 │ 55.38±2.41 │ 43.20 │
+│ Event_ACE05_Trigger │  46.60±1.09 │ 57.21±3.51 │ 59.67±3.20 │ 54.49 │
+│ Event_ACE05_Arg     │  21.60±3.61 │ 34.43±3.63 │ 39.62±2.60 │ 31.88 │
+│ ABSA_16res          │  8.10±18.11 │ 52.73±5.52 │ 57.32±1.73 │ 39.38 │
+└─────────────────────┴─────────────┴────────────┴────────────┴───────┘
+"""
diff --git a/src/inference.py b/src/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..c501bf8de812c3e1e097a2e87a3a2898813568c8
--- /dev/null
+++ b/src/inference.py
@@ -0,0 +1,23 @@
+import os
+
+from rex.utils.logging import logger
+
+from src.task import MrcTaggingTask
+
+if __name__ == "__main__":
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+    task = MrcTaggingTask.from_taskdir(
+        "outputs/bert_mrc_ner",
+        load_best_model=True,
+        update_config={
+            "skip_train": True,
+            "debug_mode": False,
+        },
+    )
+
+    cases = ["123123", "123123"]
+    logger.info(f"Cases: {cases}")
+
+    ents = task.predict(cases)
+    logger.info(f"Results: {ents}")
diff --git a/src/metric.py b/src/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f4a18e0fac965b239c2e8fc2c490746e06929c1
--- /dev/null
+++ b/src/metric.py
@@ -0,0 +1,555 @@
+from collections import defaultdict
+from typing import Tuple
+
+from rex.metrics import calc_p_r_f1_from_tp_fp_fn, safe_division
+from rex.metrics.base import MetricBase
+from rex.metrics.tagging import tagging_prf1
+from rex.utils.batch import decompose_batch_into_instances
+from rex.utils.iteration import windowed_queue_iter
+from rex.utils.random import generate_random_string_with_datetime
+from sklearn.metrics import accuracy_score, matthews_corrcoef
+
+
+class MrcNERMetric(MetricBase):
+    def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
+        gold_instances =
[] + pred_instances = [] + + batch_gold = decompose_batch_into_instances(raw_batch) + assert len(batch_gold) == len(out_batch["pred"]) + + for i, gold in enumerate(batch_gold): + gold_instances.append( + { + "id": gold["id"], + "ents": {(gold["ent_type"], gent) for gent in gold["gold_ents"]}, + } + ) + pred_instances.append( + { + "id": gold["id"], + "ents": {(gold["ent_type"], pent) for pent in out_batch["pred"][i]}, + } + ) + + return gold_instances, pred_instances + + def calculate_scores(self, golds: list, preds: list) -> dict: + id2gold = defaultdict(set) + id2pred = defaultdict(set) + # aggregate all ents with diff queries before evaluating + for gold in golds: + id2gold[gold["id"]].update(gold["ents"]) + for pred in preds: + id2pred[pred["id"]].update(pred["ents"]) + assert len(id2gold) == len(id2pred) + + gold_ents = [] + pred_ents = [] + for _id in id2gold: + gold_ents.append(id2gold[_id]) + pred_ents.append(id2pred[_id]) + + return tagging_prf1(gold_ents, pred_ents, type_idx=0) + + +class MrcSpanMetric(MetricBase): + def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple: + gold_instances = [] + pred_instances = [] + + batch_gold = decompose_batch_into_instances(raw_batch) + assert len(batch_gold) == len(out_batch["pred"]) + + for i, gold in enumerate(batch_gold): + gold_instances.append( + { + "id": gold["id"], + "spans": set(tuple(span) for span in gold["gold_spans"]), + } + ) + pred_instances.append( + { + "id": gold["id"], + "spans": set(out_batch["pred"][i]), + } + ) + + return gold_instances, pred_instances + + def calculate_scores(self, golds: list, preds: list) -> dict: + id2gold = defaultdict(set) + id2pred = defaultdict(set) + # aggregate all ents with diff queries before evaluating + for gold in golds: + id2gold[gold["id"]].update(gold["spans"]) + for pred in preds: + id2pred[pred["id"]].update(pred["spans"]) + assert len(id2gold) == len(id2pred) + + gold_spans = [] + pred_spans = [] + for _id in id2gold: + gold_spans.append(id2gold[_id]) + pred_spans.append(id2pred[_id]) + + return tagging_prf1(gold_spans, pred_spans, type_idx=None) + + +def calc_char_event(golds, preds): + """ + Calculate char-level event argument scores + + References: + - https://aistudio.baidu.com/aistudio/competition/detail/46/0/submit-result + + Args: + golds: a list of gold answers (a list of `event_list`), len=#data, + format is a list of `event_list` + preds: a list of pred answers, len=#data + """ + + def _match_arg_char_f1(gold_arg, pred_args): + gtype, grole, gstring = gold_arg + gchars = set(gstring) + garg_len = len(gchars) + cands = [] + for parg in pred_args: + if parg[0] == gtype and parg[1] == grole: + pchars = set(str(parg[-1])) + parg_len = len(pchars) + pmatch = len(pchars & gchars) + p = safe_division(pmatch, parg_len) + r = safe_division(pmatch, garg_len) + f1 = safe_division(2 * p * r, p + r) + cands.append(f1) + if len(cands) > 0: + f1 = sorted(cands)[-1] + return f1 + else: + return 0.0 + + pscore = num_gargs = num_pargs = 0 + for _golds, _preds in zip(golds, preds): + # _golds and _preds pair in one data instance + gold_args = [] + pred_args = [] + for gold in _golds: + for arg in gold.get("arguments", []): + gold_args.append( + (gold.get("event_type"), arg.get("role"), arg.get("argument")) + ) + for pred in _preds: + for arg in pred.get("arguments", []): + pred_args.append( + (pred.get("event_type"), arg.get("role"), arg.get("argument")) + ) + + num_gargs += len(gold_args) + num_pargs += len(pred_args) + for gold_arg in gold_args: + pscore += 
_match_arg_char_f1(gold_arg, pred_args) + + p = safe_division(pscore, num_pargs) + r = safe_division(pscore, num_gargs) + f1 = safe_division(2 * p * r, p + r) + return { + "p": p, + "r": r, + "f1": f1, + "pscore": pscore, + "num_pargs": num_pargs, + "num_gargs": num_gargs, + } + + +def calc_trigger_identification_metrics(golds, preds): + tp = fp = fn = 0 + for _golds, _preds in zip(golds, preds): + gold_triggers = {gold["trigger"] for gold in _golds} + pred_triggers = {pred["trigger"] for pred in _preds} + tp += len(gold_triggers & pred_triggers) + fp += len(pred_triggers - gold_triggers) + fn += len(gold_triggers - pred_triggers) + metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn) + return metrics + + +def calc_trigger_classification_metrics(golds, preds): + tp = fp = fn = 0 + for _golds, _preds in zip(golds, preds): + gold_tgg_cls = {(gold["trigger"], gold["event_type"]) for gold in _golds} + pred_tgg_cls = {(pred["trigger"], pred["event_type"]) for pred in _preds} + tp += len(gold_tgg_cls & pred_tgg_cls) + fp += len(pred_tgg_cls - gold_tgg_cls) + fn += len(gold_tgg_cls - pred_tgg_cls) + metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn) + return metrics + + +def calc_arg_identification_metrics(golds, preds): + """Calculate argument identification metrics + + Notice: + An entity could take different roles in an event, + so the base number must be calculated by + (arg, event type, pos, role) + """ + tp = fp = fn = 0 + for _golds, _preds in zip(golds, preds): + gold_args = set() + pred_args = set() + for gold in _golds: + _args = { + (arg["role"], arg["argument"], gold["event_type"]) + for arg in gold["arguments"] + } + gold_args.update(_args) + for pred in _preds: + _args = { + (arg["role"], arg["argument"], pred["event_type"]) + for arg in pred["arguments"] + } + pred_args.update(_args) + # logic derived from OneIE + _tp = 0 + _tp_fp = len(pred_args) + _tp_fn = len(gold_args) + _gold_args_wo_role = {_ga[1:] for _ga in gold_args} + for pred_arg in pred_args: + if pred_arg[1:] in _gold_args_wo_role: + _tp += 1 + tp += _tp + fp += _tp_fp - _tp + fn += _tp_fn - _tp + metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn) + return metrics + + +def calc_arg_classification_metrics(golds, preds): + tp = fp = fn = 0 + for _golds, _preds in zip(golds, preds): + gold_arg_cls = set() + pred_arg_cls = set() + for gold in _golds: + _args = { + (arg["argument"], arg["role"], gold["event_type"]) + for arg in gold["arguments"] + } + gold_arg_cls.update(_args) + for pred in _preds: + _args = { + (arg["argument"], arg["role"], pred["event_type"]) + for arg in pred["arguments"] + } + pred_arg_cls.update(_args) + tp += len(gold_arg_cls & pred_arg_cls) + fp += len(pred_arg_cls - gold_arg_cls) + fn += len(gold_arg_cls - pred_arg_cls) + metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn) + return metrics + + +def calc_ent(golds, preds): + """ + Args: + golds, preds: [(type, index list), ...] 
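# Worked example for calc_char_event above (illustrative): arguments are
# compared as character sets, so partially overlapping surface forms earn
# partial credit instead of a hard miss.
gold_chars = set("New York City")  # 12 distinct characters
pred_chars = set("New York")       # 8 distinct characters, all inside gold
match = len(gold_chars & pred_chars)
p = match / len(pred_chars)
r = match / len(gold_chars)
f1 = 2 * p * r / (p + r)
# p = 8/8 = 1.0, r = 8/12, so f1 = 0.8; the best such F1 over candidate
# predictions with the same event type and role is what feeds into pscore.
assert abs(f1 - 0.8) < 1e-9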
+ """ + res = tagging_prf1(golds, preds, type_idx=0) + return res + + +def calc_rel(golds, preds): + gold_ents = [] + pred_ents = [] + for gold, pred in zip(golds, preds): + gold_ins_ents = [] + for t in gold: + gold_ins_ents.extend(t[1:]) + gold_ents.append(gold_ins_ents) + pred_ins_ents = [] + for t in pred: + pred_ins_ents.extend(t[1:]) + pred_ents.append(pred_ins_ents) + + metrics = { + "ent": tagging_prf1(gold_ents, pred_ents, type_idx=None), + "rel": tagging_prf1(golds, preds, type_idx=None), + } + return metrics + + +def calc_cls(golds, preds): + metrics = { + "mcc": -1, + "acc": -1, + "mf1": tagging_prf1(golds, preds, type_idx=None), + } + y_true = [] + y_pred = [] + for gold, pred in zip(golds, preds): + y_true.append(" ".join(sorted(gold))) + y_pred.append(" ".join(sorted(pred))) + if y_true and y_pred: + metrics["acc"] = accuracy_score(y_true, y_pred) + else: + metrics["acc"] = 0.0 + metrics["mcc"] = matthews_corrcoef(y_true, y_pred) + return metrics + + +def calc_span(golds, preds, mode="span"): + def _get_tokens(spans: list[tuple[tuple[int]]]) -> list[int]: + tokens = [] + for span in spans: + for part in span: + _toks = [] + if len(part) == 1: + _toks = [part[0]] + elif len(part) > 1: + if mode == "w2": + _toks = [*part] + elif mode == "span": + _toks = [*range(part[0], part[1] + 1)] + else: + raise ValueError + tokens.extend(_toks) + return tokens + + metrics = { + "em": -1, + "f1": None, + } + acc_num = 0 + tp = fp = fn = 0 + for gold, pred in zip(golds, preds): + if gold == pred: + acc_num += 1 + gold_tokens = _get_tokens(gold) + pred_tokens = _get_tokens(pred) + tp += len(set(gold_tokens) & set(pred_tokens)) + fp += len(set(pred_tokens) - set(gold_tokens)) + fn += len(set(gold_tokens) - set(pred_tokens)) + if len(golds) > 0: + metrics["em"] = acc_num / len(golds) + else: + metrics["em"] = 0.0 + metrics["f1"] = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn) + return metrics + + +class MultiPartSpanMetric(MetricBase): + def _encode_span_to_label_dict(self, span_to_label: dict) -> list: + span_to_label_list = [] + for key, val in span_to_label.items(): + span_to_label_list.append({"key": key, "val": val}) + return span_to_label_list + + def _decode_span_to_label(self, span_to_label_list: list) -> dict: + span_to_label = {} + for content in span_to_label_list: + span_to_label[tuple(content["key"])] = content["val"] + return span_to_label + + def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple: + gold_instances = [] + pred_instances = [] + + batch_gold = decompose_batch_into_instances(raw_batch) + assert len(batch_gold) == len(out_batch["pred"]) + + for i, gold in enumerate(batch_gold): + ins_id = gold["raw"].get("id", generate_random_string_with_datetime()) + # encode to list to make the span_to_label dict json-serializable + # where the original dict key is a tuple + span_to_label_list = self._encode_span_to_label_dict(gold["span_to_label"]) + gold["span_to_label"] = span_to_label_list + gold_instances.append( + { + "id": ins_id, + "span_to_label_list": span_to_label_list, + "raw_gold_content": gold, + "spans": set( + tuple(multi_part_span) for multi_part_span in gold["spans"] + ), + } + ) + pred_instances.append( + { + "id": ins_id, + "spans": set( + tuple(multi_part_span) + for multi_part_span in out_batch["pred"][i] + ), + } + ) + + return gold_instances, pred_instances + + def calculate_scores(self, golds: list, preds: list) -> dict: + # for general purpose evaluation + general_gold_spans, general_pred_spans = [], [] + # cls task + gold_cls_list, 
pred_cls_list = [], [] + # ent task + gold_ent_list, pred_ent_list = [], [] + # rel task + gold_rel_list, pred_rel_list = [], [] + # event task + gold_event_list, pred_event_list = [], [] + # span task + gold_span_list, pred_span_list = [], [] + # discon ent task + gold_discon_ent_list, pred_discon_ent_list = [], [] + # hyper rel task + gold_hyper_rel_list, pred_hyper_rel_list = [], [] + + for gold, pred in zip(golds, preds): + general_gold_spans.append(gold["spans"]) + general_pred_spans.append(pred["spans"]) + span_to_label = self._decode_span_to_label(gold["span_to_label_list"]) + gold_clses, pred_clses = [], [] + gold_ents, pred_ents = [], [] + gold_rels, pred_rels = [], [] + gold_trigger_to_event = defaultdict( + lambda: {"event_type": "", "arguments": []} + ) + pred_trigger_to_event = defaultdict( + lambda: {"event_type": "", "arguments": []} + ) + gold_events, pred_events = [], [] + gold_spans, pred_spans = [], [] + gold_discon_ents, pred_discon_ents = [], [] + gold_hyper_rels, pred_hyper_rels = [], [] + + raw_schema = gold["raw_gold_content"]["raw"]["schema"] + for span in gold["spans"]: + if span[0] in span_to_label: + label = span_to_label[span[0]] + if label["task"] == "cls" and len(span) == 1: + gold_clses.append(label["string"]) + elif label["task"] == "ent" and len(span) == 2: + gold_ents.append((label["string"], *span[1:])) + elif label["task"] == "rel" and len(span) == 3: + gold_rels.append((label["string"], *span[1:])) + elif label["task"] == "event": + if label["type"] == "lm" and len(span) == 2: + gold_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip + elif label["type"] == "lr" and len(span) == 3: + gold_trigger_to_event[span[1]]["arguments"].append( + {"argument": span[2], "role": label["string"]} + ) + elif label["task"] == "discontinuous_ent" and len(span) > 1: + gold_discon_ents.append((label["string"], *span[1:])) + elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip + q_label = span_to_label[span[3]] + gold_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4])) # fmt: skip + else: + # span task has no labels + gold_spans.append(tuple(span)) + for trigger, item in gold_trigger_to_event.items(): + legal_roles = raw_schema["event"][item["event_type"]] + gold_events.append( + { + "trigger": trigger, + "event_type": item["event_type"], + "arguments": [ + arg + for arg in filter( + lambda arg: arg["role"] in legal_roles, + item["arguments"], + ) + ], + } + ) + + for span in pred["spans"]: + if span[0] in span_to_label: + label = span_to_label[span[0]] + if label["task"] == "cls" and len(span) == 1: + pred_clses.append(label["string"]) + elif label["task"] == "ent" and len(span) == 2: + pred_ents.append((label["string"], *span[1:])) + elif label["task"] == "rel" and len(span) == 3: + pred_rels.append((label["string"], *span[1:])) + elif label["task"] == "event": + if label["type"] == "lm" and len(span) == 2: + pred_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip + elif label["type"] == "lr" and len(span) == 3: + pred_trigger_to_event[span[1]]["arguments"].append( + {"argument": span[2], "role": label["string"]} + ) + elif label["task"] == "discontinuous_ent" and len(span) > 1: + pred_discon_ents.append((label["string"], *span[1:])) + elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip + q_label = span_to_label[span[3]] + pred_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4])) # fmt: 
skip + else: + # span task has no labels + pred_spans.append(tuple(span)) + for trigger, item in pred_trigger_to_event.items(): + if item["event_type"] not in raw_schema["event"]: + continue + legal_roles = raw_schema["event"][item["event_type"]] + pred_events.append( + { + "trigger": trigger, + "event_type": item["event_type"], + "arguments": [ + arg + for arg in filter( + lambda arg: arg["role"] in legal_roles, + item["arguments"], + ) + ], + } + ) + + gold_cls_list.append(gold_clses) + pred_cls_list.append(pred_clses) + gold_ent_list.append(gold_ents) + pred_ent_list.append(pred_ents) + gold_rel_list.append(gold_rels) + pred_rel_list.append(pred_rels) + gold_event_list.append(gold_events) + pred_event_list.append(pred_events) + gold_span_list.append(gold_spans) + pred_span_list.append(pred_spans) + gold_discon_ent_list.append(gold_discon_ents) + pred_discon_ent_list.append(pred_discon_ents) + gold_hyper_rel_list.append(gold_hyper_rels) + pred_hyper_rel_list.append(pred_hyper_rels) + + metrics = { + "general_spans": tagging_prf1( + general_gold_spans, general_pred_spans, type_idx=None + ), + "cls": calc_cls(gold_cls_list, pred_cls_list), + "ent": calc_ent(gold_ent_list, pred_ent_list), + "rel": calc_rel(gold_rel_list, pred_rel_list), + "event": { + "trigger_id": calc_trigger_identification_metrics( + gold_event_list, pred_event_list + ), + "trigger_cls": calc_trigger_classification_metrics( + gold_event_list, pred_event_list + ), + "arg_id": calc_arg_identification_metrics( + gold_event_list, pred_event_list + ), + "arg_cls": calc_arg_classification_metrics( + gold_event_list, pred_event_list + ), + "char_event": calc_char_event(gold_event_list, pred_event_list), + }, + "discontinuous_ent": tagging_prf1( + gold_discon_ent_list, pred_discon_ent_list, type_idx=None + ), + "hyper_rel": tagging_prf1( + gold_hyper_rel_list, pred_hyper_rel_list, type_idx=None + ), + # "span": tagging_prf1(gold_span_list, pred_span_list, type_idx=None), + "span": calc_span(gold_span_list, pred_span_list), + } + + return metrics diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..57ba605e6752cab782a437f7ddfbd8054fac36d7 --- /dev/null +++ b/src/model.py @@ -0,0 +1,533 @@ +import torch +import torch.nn as nn +from rex.utils.iteration import windowed_queue_iter +from transformers import AutoModel, BertModel + +from src.utils import decode_nnw_nsw_thw_mat, decode_nnw_thw_mat, decode_pointer_mat + + +class Biaffine(nn.Module): + """Biaffine transformation + + References: + - https://github.com/yzhangcs/parser/blob/main/supar/modules/affine.py + - https://github.com/ljynlp/W2NER + """ + + def __init__(self, n_in, n_out=2, bias_x=True, bias_y=True): + super().__init__() + + self.n_in = n_in + self.n_out = n_out + self.bias_x = bias_x + self.bias_y = bias_y + weight = torch.zeros(n_out, n_in + int(bias_x), n_in + int(bias_y)) + nn.init.xavier_normal_(weight) + self.weight = nn.Parameter(weight, requires_grad=True) + + def extra_repr(self): + s = f"n_in={self.n_in}, n_out={self.n_out}" + if self.bias_x: + s += f", bias_x={self.bias_x}" + if self.bias_y: + s += f", bias_y={self.bias_y}" + + return s + + def forward(self, x, y): + if self.bias_x: + x = torch.cat((x, torch.ones_like(x[..., :1])), -1) + if self.bias_y: + y = torch.cat((y, torch.ones_like(y[..., :1])), -1) + # [batch_size, n_out, seq_len, seq_len] + s = torch.einsum("bxi,oij,byj->boxy", x, self.weight, y) + # s = s.permute(0, 2, 3, 1) + + return s + + +class LinearWithAct(nn.Module): + def 
__init__(self, n_in, n_out, dropout=0) -> None: + super().__init__() + + self.linear = nn.Linear(n_in, n_out) + self.act_fn = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.linear(x) + x = self.act_fn(x) + x = self.dropout(x) + return x + + +class PointerMatrix(nn.Module): + """Pointer Matrix Prediction + + References: + - https://github.com/ljynlp/W2NER + """ + + def __init__( + self, + hidden_size, + biaffine_size, + cls_num=2, + dropout=0, + biaffine_bias=False, + use_rope=False, + ): + super().__init__() + self.linear_h = LinearWithAct( + n_in=hidden_size, n_out=biaffine_size, dropout=dropout + ) + self.linear_t = LinearWithAct( + n_in=hidden_size, n_out=biaffine_size, dropout=dropout + ) + self.biaffine = Biaffine( + n_in=biaffine_size, + n_out=cls_num, + bias_x=biaffine_bias, + bias_y=biaffine_bias, + ) + self.use_rope = use_rope + + def sinusoidal_position_embedding(self, qw, kw): + batch_size, seq_len, output_dim = qw.shape + position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1) + + indices = torch.arange(0, output_dim // 2, dtype=torch.float) + indices = torch.pow(10000, -2 * indices / output_dim) + pos_emb = position_ids * indices + pos_emb = torch.stack([torch.sin(pos_emb), torch.cos(pos_emb)], dim=-1) + pos_emb = pos_emb.repeat((batch_size, *([1] * len(pos_emb.shape)))) + pos_emb = torch.reshape(pos_emb, (batch_size, seq_len, output_dim)) + pos_emb = pos_emb.to(qw) + + # (bs, seq_len, 1, hz) -> (bs, seq_len, hz) + cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) + # (bs, seq_len, 1, hz) -> (bs, seq_len, hz) + sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) + qw2 = torch.cat([-qw[..., 1::2], qw[..., ::2]], -1) + qw = qw * cos_pos + qw2 * sin_pos + kw2 = torch.cat([-kw[..., 1::2], kw[..., ::2]], -1) + kw = kw * cos_pos + kw2 * sin_pos + return qw, kw + + def forward(self, x): + h = self.linear_h(x) + t = self.linear_t(x) + if self.use_rope: + h, t = self.sinusoidal_position_embedding(h, t) + o = self.biaffine(h, t) + return o + + +def multilabel_categorical_crossentropy(y_pred, y_true, bit_mask=None): + """ + https://kexue.fm/archives/7359 + https://github.com/gaohongkui/GlobalPointer_pytorch/blob/main/common/utils.py + """ + y_pred = (1 - 2 * y_true) * y_pred # -1 -> pos classes, 1 -> neg classes + y_pred_neg = y_pred - y_true * 1e12 # mask the pred outputs of pos classes + y_pred_pos = y_pred - (1 - y_true) * 1e12 # mask the pred outputs of neg classes + zeros = torch.zeros_like(y_pred[..., :1]) + y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1) + y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1) + neg_loss = torch.logsumexp(y_pred_neg, dim=-1) + pos_loss = torch.logsumexp(y_pred_pos, dim=-1) + + if bit_mask is None: + return neg_loss + pos_loss + else: + raise NotImplementedError + + +class MrcPointerMatrixModel(nn.Module): + def __init__( + self, + plm_dir: str, + cls_num: int = 2, + biaffine_size: int = 384, + none_type_id: int = 0, + text_mask_id: int = 4, + dropout: float = 0.3, + ): + super().__init__() + + # num of predicted classes, default is 3: None, NNW and THW + self.cls_num = cls_num + # None type id: 0, Next Neighboring Word (NNW): 1, Tail Head Word (THW): 2 + self.none_type_id = none_type_id + # input: cls instruction sep text sep pad + # mask: 1 2 3 4 5 0 + self.text_mask_id = text_mask_id + + self.plm = BertModel.from_pretrained(plm_dir) + hidden_size = self.plm.config.hidden_size + # self.biaffine_size = biaffine_size + self.nnw_mat = PointerMatrix( + hidden_size, biaffine_size, 
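# Quick behavioral check of multilabel_categorical_crossentropy above
# (illustrative): logits are compared against a zero threshold, so no sigmoid
# is applied to y_pred, and raising a positive-class logit lowers the loss.
import torch

y_true = torch.tensor([[1.0, 0.0, 0.0]])
weak = multilabel_categorical_crossentropy(torch.tensor([[0.1, 0.0, 0.0]]), y_true)
strong = multilabel_categorical_crossentropy(torch.tensor([[5.0, 0.0, 0.0]]), y_true)
assert strong.item() < weak.item()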
cls_num=2, dropout=dropout + ) + self.thw_mat = PointerMatrix( + hidden_size, biaffine_size, cls_num=2, dropout=dropout + ) + self.criterion = nn.CrossEntropyLoss() + + def input_encoding(self, input_ids, mask): + attention_mask = mask.gt(0).float() + plm_outputs = self.plm( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + return plm_outputs.last_hidden_state + + def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor: + # mask: (batch_size, seq_len) + bs, seq_len = mask.shape + mask_mat = ( + mask.eq(self.text_mask_id).unsqueeze(-1).expand((bs, seq_len, seq_len)) + ) + # bit_mask: (batch_size, seq_len, seq_len, 1) + bit_mask = ( + torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).long() + ) + return bit_mask + + def forward(self, input_ids, mask, labels=None, is_eval=False, **kwargs): + hidden = self.input_encoding(input_ids, mask) + nnw_hidden = self.nnw_mat(hidden) + thw_hidden = self.thw_mat(hidden) + # nnw_hidden = nnw_hidden / self.biaffine_size ** 0.5 + # thw_hidden = thw_hidden / self.biaffine_size ** 0.5 + # # (bs, 2, seq_len, seq_len) + bs, _, seq_len, seq_len = nnw_hidden.shape + + bit_mask = self.build_bit_mask(mask) + + results = {"logits": {"nnw": nnw_hidden, "thw": thw_hidden}} + if labels is not None: + # mean + nnw_loss = self.criterion( + nnw_hidden.permute(0, 2, 3, 1).reshape(-1, 2), + labels[:, 0, :, :].reshape(-1), + ) + thw_loss = self.criterion( + thw_hidden.permute(0, 2, 3, 1).reshape(-1, 2), + labels[:, 1, :, :].reshape(-1), + ) + loss = nnw_loss + thw_loss + results["loss"] = loss + + if is_eval: + batch_positions = self.decode(nnw_hidden, thw_hidden, bit_mask, **kwargs) + results["pred"] = batch_positions + return results + + def decode( + self, + nnw_hidden: torch.Tensor, + thw_hidden: torch.Tensor, + bit_mask: torch.Tensor, + **kwargs, + ): + # B x L x L + nnw_pred = nnw_hidden.argmax(1) + thw_pred = thw_hidden.argmax(1) + # B x 2 x L x L + pred = torch.stack([nnw_pred, thw_pred], dim=1) + pred = pred * bit_mask + + batch_preds = decode_nnw_thw_mat(pred, offsets=kwargs.get("offset")) + + return batch_preds + + +class MrcGlobalPointerModel(nn.Module): + def __init__( + self, + plm_dir: str, + use_rope: bool = True, + cls_num: int = 2, + biaffine_size: int = 384, + none_type_id: int = 0, + text_mask_id: int = 4, + dropout: float = 0.3, + mode: str = "w2", + ): + super().__init__() + + # num of predicted classes, default is 3: None, NNW and THW + self.cls_num = cls_num + # None type id: 0, Next Neighboring Word (NNW): 1, Tail Head Word (THW): 2 + self.none_type_id = none_type_id + # input: cls instruction sep text sep pad + # mask: 1 2 3 4 5 0 + self.text_mask_id = text_mask_id + self.use_rope = use_rope + + # mode: w2: w2ner, cons: consecutive spans + self.mode = mode + assert self.mode in ["w2", "cons"] + + self.plm = BertModel.from_pretrained(plm_dir) + self.hidden_size = self.plm.config.hidden_size + self.biaffine_size = biaffine_size + self.pointer = PointerMatrix( + self.hidden_size, + biaffine_size, + cls_num=2 if self.mode == "w2" else 1, + dropout=dropout, + biaffine_bias=True, + use_rope=use_rope, + ) + + def input_encoding(self, input_ids, mask): + attention_mask = mask.gt(0).float() + plm_outputs = self.plm( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + return plm_outputs.last_hidden_state + + def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor: + # mask: (batch_size, seq_len) + bs, seq_len = mask.shape + mask_mat = ( + 
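# Toy illustration of the segment-id convention used by build_bit_mask here
# (illustrative): ids follow "cls instruction sep text sep pad", encoded as
# 1 2 3 4 5 0, and only pairs of positions inside the text region
# (text_mask_id == 4) may carry pointer links.
import torch

mask = torch.tensor([[1, 2, 2, 3, 4, 4, 4, 5, 0]])  # one nine-token instance
is_text = mask.eq(4)
pair_ok = is_text.unsqueeze(-1) & is_text.unsqueeze(1)
print(pair_ok.shape)  # torch.Size([1, 9, 9]); True only if both ends are text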
            mask.eq(self.text_mask_id).unsqueeze(-1).expand((bs, seq_len, seq_len))
+        )
+        # bit_mask: (batch_size, 1, seq_len, seq_len)
+        bit_mask = (
+            torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).float()
+        )
+        if self.mode == "cons":
+            bit_mask = bit_mask.triu()
+
+        return bit_mask
+
+    def forward(
+        self, input_ids, mask, labels=None, is_eval=False, top_p=0.5, top_k=-1, **kwargs
+    ):
+        bit_mask = self.build_bit_mask(mask)
+        hidden = self.input_encoding(input_ids, mask)
+        # (bs, 2, seq_len, seq_len)
+        logits = self.pointer(hidden)
+        logits = logits * bit_mask - (1.0 - bit_mask) * 1e12
+        logits = logits / (self.biaffine_size**0.5)
+        # (bs, 2, seq_len, seq_len)
+        bs, cls_num, seq_len, seq_len = logits.shape
+
+        results = {"logits": logits}
+        if labels is not None:
+            # only check the label shape when labels are given (prediction passes None)
+            assert labels.shape == (bs, cls_num, seq_len, seq_len)
+            loss = multilabel_categorical_crossentropy(
+                logits.reshape(bs * cls_num, -1), labels.reshape(bs * cls_num, -1)
+            )
+            loss = loss.mean()
+            results["loss"] = loss
+
+        if is_eval:
+            batch_positions = self.decode(logits, top_p=top_p, top_k=top_k, **kwargs)
+            results["pred"] = batch_positions
+        return results
+
+    def calc_path_prob(self, probs, paths):
+        """
+        Args:
+            probs: (2, seq_len, seq_len) | (1, seq_len, seq_len)
+            paths: a list of paths in tuple
+
+        Returns:
+            [(path: tuple, prob: float), ...]
+        """
+        assert self.mode in ["w2", "cons"]
+        paths_with_prob = []
+        for path in paths:
+            path_prob = 1.0
+            if self.mode == "w2":
+                for se in windowed_queue_iter(path, 2, 1, drop_last=True):
+                    path_prob *= probs[0, se[0], se[-1]]
+                path_prob *= probs[1, path[-1], path[0]]
+            elif self.mode == "cons":
+                path_prob = probs[0, path[0], path[-1]]
+            paths_with_prob.append((path, path_prob))
+        return paths_with_prob
+
+    def decode(
+        self,
+        logits: torch.Tensor,
+        top_p: float = 0.5,
+        top_k: int = -1,
+        **kwargs,
+    ):
+        # mode: w2: w2ner with nnw and thw labels, cons: consecutive spans with one type of labels
+        assert self.mode in ["w2", "cons"]
+        # B x 2 x L x L
+        probs = logits.sigmoid()
+        pred = (probs > top_p).long()
+        if self.mode == "w2":
+            preds = decode_nnw_thw_mat(pred, offsets=kwargs.get("offset"))
+        elif self.mode == "cons":
+            pred = pred.triu()
+            preds = decode_pointer_mat(pred, offsets=kwargs.get("offset"))
+
+        if top_k == -1:
+            batch_preds = preds
+        else:
+            batch_preds = []
+            for i, paths in enumerate(preds):
+                paths_with_prob = self.calc_path_prob(probs[i], paths)
+                paths_with_prob.sort(key=lambda pp: pp[1], reverse=True)
+                batch_preds.append([pp[0] for pp in paths_with_prob[:top_k]])
+
+        return batch_preds
+
+
+class SchemaGuidedInstructBertModel(nn.Module):
+    def __init__(
+        self,
+        plm_dir: str,
+        vocab_size: int = None,
+        use_rope: bool = True,
+        biaffine_size: int = 512,
+        label_mask_id: int = 4,
+        text_mask_id: int = 7,
+        dropout: float = 0.3,
+    ):
+        super().__init__()
+
+        # input: [CLS] [I] Instruction [LM] PER [LM] LOC [LM] ORG [TL] Text [B] Background [SEP] [PAD]
+        # mask: 1 2 3 4 5 4 5 4 5 6 7 8 9 10 0
+        self.label_mask_id = label_mask_id
+        self.text_mask_id = text_mask_id
+        self.use_rope = use_rope
+
+        self.plm = AutoModel.from_pretrained(plm_dir)
+        if vocab_size:
+            self.plm.resize_token_embeddings(vocab_size)
+        self.hidden_size = self.plm.config.hidden_size
+        self.biaffine_size = biaffine_size
+        self.pointer = PointerMatrix(
+            self.hidden_size,
+            biaffine_size,
+            cls_num=3,
+            dropout=dropout,
+            biaffine_bias=True,
+            use_rope=use_rope,
+        )
+
+    def input_encoding(self, input_ids, mask):
+        attention_mask = mask.gt(0).float()
+        plm_outputs = self.plm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            return_dict=True,
+        )
+        return plm_outputs.last_hidden_state
+
+    def build_bit_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        # mask: (batch_size, seq_len)
+        bs, seq_len = mask.shape
+        # _m = torch.logical_or(mask.eq(self.label_mask_id), mask.eq(self.text_mask_id))
+        # mask_mat = _m.unsqueeze(-1).expand((bs, seq_len, seq_len))
+        # # bit_mask: (batch_size, 1, seq_len, seq_len)
+        # bit_mask = (
+        #     torch.logical_and(mask_mat, mask_mat.transpose(1, 2)).unsqueeze(1).float()
+        # )
+        bit_mask = (
+            mask.gt(0).unsqueeze(1).unsqueeze(1).expand(bs, 1, seq_len, seq_len).float()
+        )
+
+        return bit_mask
+
+    def forward(
+        self, input_ids, mask, labels=None, is_eval=False, top_p=0.5, top_k=-1, **kwargs
+    ):
+        bit_mask = self.build_bit_mask(mask)
+        hidden = self.input_encoding(input_ids, mask)
+        # (bs, 3, seq_len, seq_len)
+        logits = self.pointer(hidden)
+        logits = logits * bit_mask - (1.0 - bit_mask) * 1e12
+        logits = logits / (self.biaffine_size**0.5)
+        # (bs, 3, seq_len, seq_len)
+        bs, cls_num, seq_len, seq_len = logits.shape
+
+        results = {"logits": logits}
+        if labels is not None:
+            # only check the label shape when labels are given (prediction passes None)
+            assert labels.shape == (bs, cls_num, seq_len, seq_len)
+            loss = multilabel_categorical_crossentropy(
+                logits.reshape(bs * cls_num, -1), labels.reshape(bs * cls_num, -1)
+            )
+            loss = loss.mean()
+            results["loss"] = loss
+
+        if is_eval:
+            batch_positions = self.decode(logits, top_p=top_p, top_k=top_k, **kwargs)
+            results["pred"] = batch_positions
+        return results
+
+    def calc_path_prob(self, probs, paths):
+        """
+        Args:
+            probs: (2, seq_len, seq_len) | (1, seq_len, seq_len)
+            paths: a list of paths in tuple
+
+        Returns:
+            [(path: tuple, prob: float), ...]
+        """
+        paths_with_prob = []
+        for path in paths:
+            path_prob = 1.0
+            for se in windowed_queue_iter(path, 2, 1, drop_last=True):
+                path_prob *= probs[0, se[0], se[-1]]
+            path_prob *= probs[1, path[-1], path[0]]
+            paths_with_prob.append((path, path_prob))
+        return paths_with_prob
+
+    def decode(
+        self,
+        logits: torch.Tensor,
+        top_p: float = 0.5,
+        top_k: int = -1,
+        # legal_num_parts: tuple = (1, 2, 3),
+        legal_num_parts: tuple = None,
+        labels: torch.Tensor = None,
+        **kwargs,
+    ):
+        # B x 3 x L x L
+        probs = logits.sigmoid()
+        if labels is None:
+            pred = (probs > top_p).long()
+        else:
+            # `labels` is used for upper bound analysis
+            pred = labels
+        preds = decode_nnw_nsw_thw_mat(pred, offsets=kwargs.get("offset"))
+        # for pred, gold in zip(preds, kwargs.get("spans")):
+        #     sorted_pred = sorted(set(tuple(x) for x in pred))
+        #     sorted_gold = sorted(set(tuple(x) for x in gold))
+        #     if sorted_pred != sorted_gold:
+        #         breakpoint()
+
+        if top_k == -1:
+            batch_preds = preds
+        else:
+            batch_preds = []
+            for i, paths in enumerate(preds):
+                paths_with_prob = self.calc_path_prob(probs[i], paths)
+                paths_with_prob.sort(key=lambda pp: pp[1], reverse=True)
+                batch_preds.append([pp[0] for pp in paths_with_prob[:top_k]])
+
+        if legal_num_parts is not None:
+            legal_preds = []
+            for ins_paths in batch_preds:
+                legal_paths = []
+                for path in ins_paths:
+                    if len(path) in legal_num_parts:
+                        legal_paths.append(path)
+                legal_preds.append(legal_paths)
+        else:
+            legal_preds = batch_preds
+
+        return legal_preds
diff --git a/src/preprocess.py b/src/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/task.py b/src/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..52285fdd5cacbbb2e8c3276343884c111647d89a
---
/dev/null +++ b/src/task.py @@ -0,0 +1,590 @@ +import math +import re +from collections import defaultdict +from datetime import datetime +from typing import List + +import torch +import torch.optim as optim +from rex import accelerator +from rex.data.data_manager import DataManager +from rex.data.dataset import CachedDataset, StreamReadDataset +from rex.tasks.simple_metric_task import SimpleMetricTask +from rex.utils.batch import decompose_batch_into_instances +from rex.utils.config import ConfigParser +from rex.utils.dict import flatten_dict +from rex.utils.io import load_jsonlines +from rex.utils.registry import register +from torch.utils.tensorboard import SummaryWriter +from transformers.optimization import ( + get_cosine_schedule_with_warmup, + get_linear_schedule_with_warmup, +) + +from .metric import MrcNERMetric, MrcSpanMetric, MultiPartSpanMetric +from .model import ( + MrcGlobalPointerModel, + MrcPointerMatrixModel, + SchemaGuidedInstructBertModel, +) +from .transform import ( + CachedLabelPointerTransform, + CachedPointerMRCTransform, + CachedPointerTaggingTransform, +) + + +@register("task") +class MrcTaggingTask(SimpleMetricTask): + def __init__(self, config, **kwargs) -> None: + super().__init__(config, **kwargs) + + def after_initialization(self): + now_string = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + self.tb_logger: SummaryWriter = SummaryWriter( + log_dir=self.task_path / "tb_summary" / now_string, + comment=self.config.comment, + ) + + def after_whole_train(self): + self.tb_logger.close() + + def get_grad_norm(self): + # for name, param in self.model.named_parameters(): + # if param.grad is not None: + # grads = param.grad.detach().data + # grad_norm = (grads.norm(p=2) / grads.numel()).item() + total_norm = 0.0 + for p in self.model.parameters(): + if p.grad is not None: + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** (1.0 / 2) + return total_norm + + def log_loss( + self, idx: int, loss_item: float, step_or_epoch: str, dataset_name: str + ): + self.tb_logger.add_scalar( + f"loss/{dataset_name}/{step_or_epoch}", loss_item, idx + ) + # self.tb_logger.add_scalars( + # "lr", + # { + # str(i): self.optimizer.param_groups[i]["lr"] + # for i in range(len(self.optimizer.param_groups)) + # }, + # idx, + # ) + self.tb_logger.add_scalar("lr", self.optimizer.param_groups[0]["lr"], idx) + self.tb_logger.add_scalar("grad_norm_total", self.get_grad_norm(), idx) + + def log_metrics( + self, idx: int, metrics: dict, step_or_epoch: str, dataset_name: str + ): + metrics = flatten_dict(metrics) + self.tb_logger.add_scalars(f"{dataset_name}/{step_or_epoch}", metrics, idx) + + def init_transform(self): + return CachedPointerTaggingTransform( + self.config.max_seq_len, + self.config.plm_dir, + self.config.ent_type2query_filepath, + mode=self.config.mode, + negative_sample_prob=self.config.negative_sample_prob, + ) + + def init_data_manager(self): + return DataManager( + self.config.train_filepath, + self.config.dev_filepath, + self.config.test_filepath, + CachedDataset, + self.transform, + load_jsonlines, + self.config.train_batch_size, + self.config.eval_batch_size, + self.transform.collate_fn, + use_stream_transform=False, + debug_mode=self.config.debug_mode, + dump_cache_dir=self.config.dump_cache_dir, + regenerate_cache=self.config.regenerate_cache, + ) + + def init_model(self): + # m = MrcPointerMatrixModel( + m = MrcGlobalPointerModel( + self.config.plm_dir, + biaffine_size=self.config.biaffine_size, + 
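# Sanity check for get_grad_norm above (illustrative): accumulating squared
# per-parameter gradient norms and taking the square root equals the L2 norm
# of all gradients flattened into one vector.
import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
total = sum(p.grad.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
flat = torch.cat([p.grad.reshape(-1) for p in model.parameters()])
assert abs(total - flat.norm(2).item()) < 1e-5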
dropout=self.config.dropout, + mode=self.config.mode, + ) + return m + + def init_metric(self): + return MrcNERMetric() + + def init_optimizer(self): + no_decay = r"(embedding|LayerNorm|\.bias$)" + plm_lr = r"^plm\." + non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])" + + param_groups = [] + for name, param in self.model.named_parameters(): + lr = self.config.learning_rate + weight_decay = self.config.weight_decay + if re.search(non_trainable, name): + param.requires_grad = False + if not re.search(plm_lr, name): + lr = self.config.other_learning_rate + if re.search(no_decay, name): + weight_decay = 0.0 + param_groups.append( + {"params": param, "lr": lr, "weight_decay": weight_decay} + ) + return optim.AdamW( + param_groups, + lr=self.config.learning_rate, + betas=(0.9, 0.98), + eps=1e-6, + ) + + def init_lr_scheduler(self): + num_training_steps = int( + len(self.data_manager.train_loader) + * self.config.num_epochs + * accelerator.num_processes + ) + num_warmup_steps = math.floor( + num_training_steps * self.config.warmup_proportion + ) + # return get_linear_schedule_with_warmup( + return get_cosine_schedule_with_warmup( + self.optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + + def predict_api(self, texts: List[str], **kwargs): + raw_dataset = self.transform.predict_transform(texts) + text_ids = sorted(list({ins["id"] for ins in raw_dataset})) + loader = self.data_manager.prepare_loader(raw_dataset) + # to prepare input device + loader = accelerator.prepare_data_loader(loader) + id2ents = defaultdict(set) + for batch in loader: + batch_out = self.model(**batch, is_eval=True) + for _id, _pred in zip(batch["id"], batch_out["pred"]): + id2ents[_id].update(_pred) + results = [id2ents[_id] for _id in text_ids] + + return results + + +@register("task") +class MrcQaTask(MrcTaggingTask): + def init_transform(self): + return CachedPointerMRCTransform( + self.config.max_seq_len, + self.config.plm_dir, + mode=self.config.mode, + ) + + def init_model(self): + # m = MrcPointerMatrixModel( + m = MrcGlobalPointerModel( + self.config.plm_dir, + biaffine_size=self.config.biaffine_size, + dropout=self.config.dropout, + mode=self.config.mode, + ) + return m + + def init_metric(self): + return MrcSpanMetric() + + def predict_api(self, data: list[dict], **kwargs): + """ + Args: + data: a list of dict with query, context, and background strings + """ + raw_dataset = self.transform.predict_transform(data) + loader = self.data_manager.prepare_loader(raw_dataset) + results = [] + for batch in loader: + batch_out = self.model(**batch, is_eval=True) + batch["pred"] = batch_out["pred"] + instances = decompose_batch_into_instances(batch) + for ins in instances: + preds = ins["pred"] + ins_results = [] + for index_list in preds: + ins_result = [] + for i in index_list: + ins_result.append(ins["raw_tokens"][i]) + ins_results.append(("".join(ins_result), tuple(index_list))) + results.append(ins_results) + + return results + + +class StreamReadDatasetWithLen(StreamReadDataset): + def __len__(self): + return 631346 + + +@register("task") +class SchemaGuidedInstructBertTask(MrcTaggingTask): + # def __init__(self, config, **kwargs) -> None: + # super().__init__(config, **kwargs) + + # from watchmen import ClientMode, WatchClient + + # client = WatchClient( + # id=config.task_name, + # gpus=[4], + # req_gpu_num=1, + # mode=ClientMode.SCHEDULE, + # server_host="127.0.0.1", + # server_port=62333, + # ) + # client.wait() + + # def init_lr_scheduler(self): + # num_training_steps 
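# Worked example for init_lr_scheduler above (illustrative numbers): with
# 1,000 batches per epoch, 3 epochs, a single process, and
# warmup_proportion = 0.1, the cosine schedule warms up for the first 300 of
# 3,000 optimizer steps.
import math

num_training_steps = int(1000 * 3 * 1)
num_warmup_steps = math.floor(num_training_steps * 0.1)
assert (num_training_steps, num_warmup_steps) == (3000, 300)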
= int( + # 631346 / self.config.train_batch_size + # * self.config.num_epochs + # * accelerator.num_processes + # ) + # num_warmup_steps = math.floor( + # num_training_steps * self.config.warmup_proportion + # ) + # # return get_linear_schedule_with_warmup( + # return get_cosine_schedule_with_warmup( + # self.optimizer, + # num_warmup_steps=num_warmup_steps, + # num_training_steps=num_training_steps, + # ) + + def init_transform(self): + self.transform: CachedLabelPointerTransform + return CachedLabelPointerTransform( + self.config.max_seq_len, + self.config.plm_dir, + mode=self.config.mode, + label_span=self.config.label_span, + include_instructions=self.config.get("include_instructions", True), + ) + + def init_data_manager(self): + if self.config.get("stream_mode", False): + DatasetClass = StreamReadDatasetWithLen + transform = self.transform.transform + else: + DatasetClass = CachedDataset + transform = self.transform + return DataManager( + self.config.train_filepath, + self.config.dev_filepath, + self.config.test_filepath, + DatasetClass, + transform, + load_jsonlines, + self.config.train_batch_size, + self.config.eval_batch_size, + self.transform.collate_fn, + use_stream_transform=self.config.get("stream_mode", False), + debug_mode=self.config.debug_mode, + dump_cache_dir=self.config.dump_cache_dir, + regenerate_cache=self.config.regenerate_cache, + ) + + def init_model(self): + self.model = SchemaGuidedInstructBertModel( + self.config.plm_dir, + vocab_size=len(self.transform.tokenizer), + use_rope=self.config.use_rope, + biaffine_size=self.config.biaffine_size, + dropout=self.config.dropout, + ) + + if self.config.get("base_model_path"): + self.load( + self.config.base_model_path, + load_config=False, + load_model=True, + load_optimizer=False, + load_history=False, + ) + return self.model + + def init_optimizer(self): + no_decay = r"(embedding|LayerNorm|\.bias$)" + plm_lr = r"^plm\." 
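# Illustrative check of the regex-based parameter grouping used by
# init_optimizer here: names under `plm.` keep the base learning rate, other
# modules get other_learning_rate, and bias/LayerNorm/embedding tensors skip
# weight decay.
import re

no_decay = r"(embedding|LayerNorm|\.bias$)"
plm_lr = r"^plm\."
assert re.search(plm_lr, "plm.encoder.layer.5.attention.self.query.weight")
assert not re.search(plm_lr, "pointer.biaffine.weight")  # gets other_learning_rate
assert re.search(no_decay, "plm.encoder.layer.5.output.LayerNorm.weight")
assert re.search(no_decay, "pointer.linear_h.linear.bias")  # weight_decay = 0.0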
+ # non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])" + non_trainable = "no_non_trainable" + + param_groups = [] + for name, param in self.model.named_parameters(): + lr = self.config.learning_rate + weight_decay = self.config.weight_decay + if re.search(non_trainable, name): + param.requires_grad = False + if not re.search(plm_lr, name): + lr = self.config.other_learning_rate + if re.search(no_decay, name): + weight_decay = 0.0 + param_groups.append( + {"params": param, "lr": lr, "weight_decay": weight_decay} + ) + return optim.AdamW( + param_groups, + lr=self.config.learning_rate, + betas=(0.9, 0.98), + eps=1e-6, + ) + + def init_metric(self): + return MultiPartSpanMetric() + + def _convert_span_to_string(self, span, token_ids, tokenizer): + string = "" + if len(span) == 0 or len(span) > 2: + pass + elif len(span) == 1: + string = tokenizer.decode(token_ids[span[0]]) + elif len(span) == 2: + string = tokenizer.decode(token_ids[span[0] : span[1] + 1]) + return (string, self.reset_position(token_ids, span)) + + def reset_position(self, token_ids: list[int], span: list[int]) -> list[int]: + if isinstance(token_ids, torch.Tensor): + input_ids = token_ids.cpu().tolist() + if len(span) < 1: + return span + + tp_token_id, tl_token_id = self.transform.tokenizer.convert_tokens_to_ids( + [self.transform.tp_token, self.transform.tl_token] + ) + offset = 0 + if tp_token_id in input_ids: + offset = input_ids.index(tp_token_id) + 1 + elif tl_token_id in input_ids: + offset = input_ids.index(tl_token_id) + 1 + return [i - offset for i in span] + + def predict_api(self, data: list[dict], **kwargs): + """ + Args: + data: a list of dict in UDI: + { + "id": str, + "instruction": str, + "schema": { + "ent": list, + "rel": list, + "event": dict, + "cls": list, + "discontinuous_ent": list, + "hyper_rel": dict + }, + "text": str, + "bg": str, + "ans": {}, # empty dict + } + """ + raw_dataset = [self.transform.transform(d) for d in data] + loader = self.data_manager.prepare_loader(raw_dataset) + results = [] + for batch in loader: + batch_out = self.model(**batch, is_eval=True) + batch["pred"] = batch_out["pred"] + instances = decompose_batch_into_instances(batch) + for ins in instances: + pred_clses = [] + pred_ents = [] + pred_rels = [] + pred_trigger_to_event = defaultdict( + lambda: {"event_type": "", "arguments": []} + ) + pred_events = [] + pred_spans = [] + pred_discon_ents = [] + pred_hyper_rels = [] + raw_schema = ins["raw"]["schema"] + for multi_part_span in ins["pred"]: + span = tuple(multi_part_span) + span_to_label = ins["span_to_label"] + if span[0] in span_to_label: + label = span_to_label[span[0]] + if label["task"] == "cls" and len(span) == 1: + pred_clses.append(label["string"]) + elif label["task"] == "ent" and len(span) == 2: + string = self._convert_span_to_string( + span[1], ins["input_ids"], self.transform.tokenizer + ) + pred_ents.append((label["string"], string)) + elif label["task"] == "rel" and len(span) == 3: + head = self._convert_span_to_string( + span[1], ins["input_ids"], self.transform.tokenizer + ) + tail = self._convert_span_to_string( + span[2], ins["input_ids"], self.transform.tokenizer + ) + pred_rels.append((label["string"], head, tail)) + elif label["task"] == "event": + if label["type"] == "lm" and len(span) == 2: + pred_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip + elif label["type"] == "lr" and len(span) == 3: + arg = self._convert_span_to_string( + span[2], ins["input_ids"], self.transform.tokenizer + ) + 
+                        elif label["task"] == "event":
+                            if label["type"] == "lm" and len(span) == 2:
+                                pred_trigger_to_event[span[1]]["event_type"] = label["string"]  # fmt: skip
+                            elif label["type"] == "lr" and len(span) == 3:
+                                arg = self._convert_span_to_string(
+                                    span[2], ins["input_ids"], self.transform.tokenizer
+                                )
+                                pred_trigger_to_event[span[1]]["arguments"].append(
+                                    {"argument": arg, "role": label["string"]}
+                                )
+                        elif label["task"] == "discontinuous_ent" and len(span) > 1:
+                            parts = [
+                                self._convert_span_to_string(
+                                    part, ins["input_ids"], self.transform.tokenizer
+                                )
+                                for part in span[1:]
+                            ]
+                            string = " ".join([part[0] for part in parts])
+                            position = []
+                            for part in parts:
+                                position.append(part[1])
+                            # part positions from _convert_span_to_string are
+                            # already offset-reset, so collect them as-is
+                            pred_discon_ents.append(
+                                (label["string"], string, position)
+                            )
+                        elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label:  # fmt: skip
+                            q_label = span_to_label[span[3]]
+                            span_1 = self._convert_span_to_string(
+                                span[1], ins["input_ids"], self.transform.tokenizer
+                            )
+                            span_2 = self._convert_span_to_string(
+                                span[2], ins["input_ids"], self.transform.tokenizer
+                            )
+                            span_4 = self._convert_span_to_string(
+                                span[4], ins["input_ids"], self.transform.tokenizer
+                            )
+                            pred_hyper_rels.append((label["string"], span_1, span_2, q_label["string"], span_4))  # fmt: skip
+                    else:
+                        # span task has no labels
+                        pred_token_ids = []
+                        for part in span:
+                            _pred_token_ids = [ins["input_ids"][i] for i in part]
+                            pred_token_ids.extend(_pred_token_ids)
+                        span_string = self.transform.tokenizer.decode(pred_token_ids)
+                        pred_spans.append(
+                            (
+                                span_string,
+                                tuple(
+                                    [
+                                        tuple(
+                                            self.reset_position(
+                                                ins["input_ids"].cpu().tolist(), part
+                                            )
+                                        )
+                                        for part in span
+                                    ]
+                                ),
+                            )
+                        )
+                for trigger, item in pred_trigger_to_event.items():
+                    trigger = self._convert_span_to_string(
+                        trigger, ins["input_ids"], self.transform.tokenizer
+                    )
+                    if item["event_type"] not in raw_schema["event"]:
+                        continue
+                    legal_roles = raw_schema["event"][item["event_type"]]
+                    pred_events.append(
+                        {
+                            "trigger": trigger,
+                            "event_type": item["event_type"],
+                            "arguments": [
+                                arg
+                                for arg in filter(
+                                    lambda arg: arg["role"] in legal_roles,
+                                    item["arguments"],
+                                )
+                            ],
+                        }
+                    )
+                results.append(
+                    {
+                        "id": ins["raw"]["id"],
+                        "results": {
+                            "cls": pred_clses,
+                            "ent": pred_ents,
+                            "rel": pred_rels,
+                            "event": pred_events,
+                            "span": pred_spans,
+                            "discon_ent": pred_discon_ents,
+                            "hyper_rel": pred_hyper_rels,
+                        },
+                    }
+                )
+
+        return results
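+
+# A minimal usage sketch for `predict_api` (illustrative only: the task dir,
+# instruction, schema, and text below are made-up values, not shipped assets):
+#
+#   task = SchemaGuidedInstructBertTask.from_taskdir(
+#       "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_Rel",
+#       initialize=True,
+#       load_config=True,
+#       dump_configfile=False,
+#   )
+#   results = task.predict_api(
+#       [
+#           {
+#               "id": "demo-0",
+#               "instruction": "Please extract entities from the given text.",
+#               "schema": {"ent": ["person", "location"]},
+#               "text": "Barack Obama was born in Hawaii.",
+#               "bg": "",
+#               "ans": {},
+#           }
+#       ]
+#   )
+#   # -> [{"id": "demo-0", "results": {"ent": [...], "rel": [...], ...}}]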
+
+
+if __name__ == "__main__":
+    pass
+    # further_finetune()
+
+    # from rex.utils.config import ConfigParser
+
+    # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/ner.yaml"])
+    # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/mirror-ace05en.yaml"])
+
+    # task = MrcTaggingTask(
+    #     config,
+    #     initialize=True,
+    #     makedirs=True,
+    #     dump_configfile=True,
+    # )
+    # task = SchemaGuidedInstructBertTask.from_taskdir(
+    #     "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_Rel",
+    #     initialize=True,
+    #     load_config=True,
+    #     dump_configfile=False,
+    # )
+    # task = SchemaGuidedInstructBertTask(
+    #     config,
+    #     initialize=True,
+    #     makedirs=True,
+    #     dump_configfile=False,
+    # )
+    # task.load(
+    #     "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_NerRelEvent/ckpt/SchemaGuidedInstructBertModel.epoch.0.pth",
+    #     load_config=False,
+    # )
+    # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval")
+    # task.load(
+    #     # "outputs/Mirror_RobertaBaseWwm_Cons_MsraMrc/ckpt/MrcGlobalPointerModel.best.pth",
+    #     # "outputs/Mirror_RobertaBaseWwm_W2_MsraMrc_HyperParamExp1/ckpt/MrcGlobalPointerModel.best.pth",
+    #     config.base_model_path,
+    #     load_config=False,
+    #     load_model=True,
+    #     load_optimizer=False,
+    #     load_history=False,
+    # )
+    # task.train()
+    # task = MrcTaggingTask.from_taskdir(
+    #     "outputs/Mirror_W2_MSRAv2_NER",
+    #     initialize=True,
+    #     dump_configfile=False,
+    #     load_config=True,
+    # )
+    # for name, _ in task.model.named_parameters():
+    #     print(name)
+    # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval.0.1")
+
+    # task = MrcQaTask(
+    #     config,
+    #     initialize=True,
+    #     makedirs=True,
+    #     dump_configfile=True,
+    # )
+    # task.train()
+    # task.eval("dev", verbose=True, dump=True, dump_middle=True, postfix="re_eval")
diff --git a/src/transform.py b/src/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fdd5371922bb75d5193aeac2f2926d65a4f436
--- /dev/null
+++ b/src/transform.py
@@ -0,0 +1,693 @@
+import random
+import re
+from collections import defaultdict
+from typing import Iterable, Iterator, List, MutableSet, Optional, Tuple, TypeVar, Union
+
+import torch
+import torch.nn.functional as F
+from rex.data.collate_fn import GeneralCollateFn
+from rex.data.transforms.base import CachedTransformBase, CachedTransformOneBase
+from rex.metrics import calc_p_r_f1_from_tp_fp_fn
+from rex.utils.io import load_json
+from rex.utils.iteration import windowed_queue_iter
+from rex.utils.logging import logger
+from transformers import AutoTokenizer
+from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
+from transformers.models.deberta_v2.tokenization_deberta_v2_fast import (
+    DebertaV2TokenizerFast,
+)
+from transformers.tokenization_utils_base import BatchEncoding
+
+from src.utils import (
+    decode_nnw_nsw_thw_mat,
+    decode_nnw_thw_mat,
+    encode_nnw_nsw_thw_mat,
+    encode_nnw_thw_mat,
+)
+
+Filled = TypeVar("Filled")
+
+
+class PaddingMixin:
+    max_seq_len: int
+
+    def pad_seq(self, batch_seqs: Iterable[Filled], fill: Filled) -> Iterable[Filled]:
+        max_len = max(len(seq) for seq in batch_seqs)
+        assert max_len <= self.max_seq_len
+        for i in range(len(batch_seqs)):
+            batch_seqs[i] = batch_seqs[i] + [fill] * (max_len - len(batch_seqs[i]))
+        return batch_seqs
+
+    def pad_mat(
+        self, mats: List[torch.Tensor], fill: Union[int, float]
+    ) -> List[torch.Tensor]:
+        max_len = max(mat.shape[0] for mat in mats)
+        assert max_len <= self.max_seq_len
+        for i in range(len(mats)):
+            num_add = max_len - mats[i].shape[0]
+            mats[i] = F.pad(
+                mats[i], (0, 0, 0, num_add, 0, num_add), mode="constant", value=fill
+            )
+        return mats
+
+
+class PointerTransformMixin:
+    tokenizer: BertTokenizerFast
+    max_seq_len: int
+    space_token: str = "[unused1]"
+
+    def build_ins(
+        self,
+        query_tokens: list[str],
+        context_tokens: list[str],
+        answer_indexes: list[list[int]],
+        add_context_tokens: list[str] = None,
+    ) -> Tuple:
+        # -3: [CLS], the [SEP] after the query, and the [SEP] after the context
+        reserved_seq_len = self.max_seq_len - 3 - len(query_tokens)
+        # reserve at least 20 tokens
+        if reserved_seq_len < 20:
+            raise ValueError(
+                f"Query {query_tokens} too long: {len(query_tokens)} "
+                f"while max seq len is {self.max_seq_len}"
+            )
+
+        input_tokens = [self.tokenizer.cls_token]
+        input_tokens += query_tokens
+        input_tokens += [self.tokenizer.sep_token]
+        offset = len(input_tokens)
+        input_tokens += context_tokens[:reserved_seq_len]
+        available_token_range = range(
+            offset, offset + len(context_tokens[:reserved_seq_len])
+        )
+        input_tokens += [self.tokenizer.sep_token]
+
+        add_context_len = 0
+        max_add_context_len = self.max_seq_len - len(input_tokens) - 1
+        add_context_flag = False
+        if add_context_tokens and len(add_context_tokens) > 0:
+            add_context_flag = True
+            add_context_len = len(add_context_tokens[:max_add_context_len])
+            input_tokens += add_context_tokens[:max_add_context_len]
+            input_tokens +=
[self.tokenizer.sep_token] + new_tokens = [] + for t in input_tokens: + if len(t.strip()) > 0: + new_tokens.append(t) + else: + new_tokens.append(self.space_token) + input_tokens = new_tokens + input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens) + + mask = [1] + mask += [2] * len(query_tokens) + mask += [3] + mask += [4] * len(context_tokens[:reserved_seq_len]) + mask += [5] + if add_context_flag: + mask += [6] * add_context_len + mask += [7] + assert len(mask) == len(input_ids) <= self.max_seq_len + + available_spans = [tuple(i + offset for i in index) for index in answer_indexes] + available_spans = list( + filter( + lambda index: all(i in available_token_range for i in index), + available_spans, + ) + ) + + token_len = len(input_ids) + pad_len = self.max_seq_len - token_len + input_tokens += pad_len * [self.tokenizer.pad_token] + input_ids += pad_len * [self.tokenizer.pad_token_id] + mask += pad_len * [0] + + return input_tokens, input_ids, mask, offset, available_spans + + def update_labels(self, data: dict) -> dict: + bs = len(data["input_ids"]) + seq_len = self.max_seq_len + labels = torch.zeros((bs, 2, seq_len, seq_len)) + for i, batch_spans in enumerate(data["available_spans"]): + # offset = data["offset"][i] + # pad_len = data["mask"].count(0) + # token_len = seq_len - pad_len + for span in batch_spans: + if len(span) == 1: + labels[i, :, span[0], span[0]] = 1 + else: + for s, e in windowed_queue_iter(span, 2, 1, drop_last=True): + labels[i, 0, s, e] = 1 + labels[i, 1, span[-1], span[0]] = 1 + # labels[i, :, 0:offset, :] = -100 + # labels[i, :, :, 0:offset] = -100 + # labels[i, :, :, token_len:] = -100 + # labels[i, :, token_len:, :] = -100 + data["labels"] = labels + return data + + def update_consecutive_span_labels(self, data: dict) -> dict: + bs = len(data["input_ids"]) + seq_len = self.max_seq_len + labels = torch.zeros((bs, 1, seq_len, seq_len)) + for i, batch_spans in enumerate(data["available_spans"]): + for span in batch_spans: + assert span == tuple(sorted(set(span))) + if len(span) == 1: + labels[i, 0, span[0], span[0]] = 1 + else: + labels[i, 0, span[0], span[-1]] = 1 + data["labels"] = labels + return data + + +class CachedPointerTaggingTransform(CachedTransformBase, PointerTransformMixin): + def __init__( + self, + max_seq_len: int, + plm_dir: str, + ent_type2query_filepath: str, + mode: str = "w2", + negative_sample_prob: float = 1.0, + ) -> None: + super().__init__() + + self.max_seq_len: int = max_seq_len + self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir) + self.ent_type2query: dict = load_json(ent_type2query_filepath) + self.negative_sample_prob = negative_sample_prob + + self.collate_fn: GeneralCollateFn = GeneralCollateFn( + { + "input_ids": torch.long, + "mask": torch.long, + "labels": torch.long, + }, + guessing=False, + missing_key_as_null=True, + ) + if mode == "w2": + self.collate_fn.update_before_tensorify = self.update_labels + elif mode == "cons": + self.collate_fn.update_before_tensorify = ( + self.update_consecutive_span_labels + ) + else: + raise ValueError(f"Mode: {mode} not recognizable") + + def transform( + self, + transform_loader: Iterator, + dataset_name: str = None, + **kwargs, + ) -> Iterable: + final_data = [] + # tp = fp = fn = 0 + for data in transform_loader: + ent_type2ents = defaultdict(set) + for ent in data["ents"]: + ent_type2ents[ent["type"]].add(tuple(ent["index"])) + for ent_type in self.ent_type2query: + gold_ents = ent_type2ents[ent_type] + if ( + len(gold_ents) < 1 + and dataset_name == 
"train" + and random.random() > self.negative_sample_prob + ): + # skip negative samples + continue + # res = self.build_ins(ent_type, data["tokens"], gold_ents) + query = self.ent_type2query[ent_type] + query_tokens = self.tokenizer.tokenize(query) + try: + res = self.build_ins(query_tokens, data["tokens"], gold_ents) + except (ValueError, AssertionError): + continue + input_tokens, input_ids, mask, offset, available_spans = res + ins = { + "id": data.get("id", str(len(final_data))), + "ent_type": ent_type, + "gold_ents": gold_ents, + "raw_tokens": data["tokens"], + "input_tokens": input_tokens, + "input_ids": input_ids, + "mask": mask, + "offset": offset, + "available_spans": available_spans, + # labels are dynamically padded in collate fn + "labels": None, + # "labels": labels.tolist(), + } + final_data.append(ins) + + # # upper bound analysis + # pred_spans = set(decode_nnw_thw_mat(labels.unsqueeze(0))[0]) + # g_ents = set(available_spans) + # tp += len(g_ents & pred_spans) + # fp += len(pred_spans - g_ents) + # fn += len(g_ents - pred_spans) + + # # upper bound results + # measures = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn) + # logger.info(f"Upper Bound: {measures}") + + return final_data + + def predict_transform(self, texts: List[str]): + dataset = [] + for text_id, text in enumerate(texts): + data_id = f"Prediction#{text_id}" + tokens = self.tokenizer.tokenize(text) + dataset.append( + { + "id": data_id, + "tokens": tokens, + "ents": [], + } + ) + final_data = self(dataset, disable_pbar=True) + return final_data + + +class CachedPointerMRCTransform(CachedTransformBase, PointerTransformMixin): + def __init__( + self, + max_seq_len: int, + plm_dir: str, + mode: str = "w2", + ) -> None: + super().__init__() + + self.max_seq_len: int = max_seq_len + self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir) + + self.collate_fn: GeneralCollateFn = GeneralCollateFn( + { + "input_ids": torch.long, + "mask": torch.long, + "labels": torch.long, + }, + guessing=False, + missing_key_as_null=True, + ) + + if mode == "w2": + self.collate_fn.update_before_tensorify = self.update_labels + elif mode == "cons": + self.collate_fn.update_before_tensorify = ( + self.update_consecutive_span_labels + ) + else: + raise ValueError(f"Mode: {mode} not recognizable") + + def transform( + self, + transform_loader: Iterator, + dataset_name: str = None, + **kwargs, + ) -> Iterable: + final_data = [] + for data in transform_loader: + try: + res = self.build_ins( + data["query_tokens"], + data["context_tokens"], + data["answer_index"], + data.get("background_tokens"), + ) + except (ValueError, AssertionError): + continue + input_tokens, input_ids, mask, offset, available_spans = res + ins = { + "id": data.get("id", str(len(final_data))), + "gold_spans": sorted(set(tuple(x) for x in data["answer_index"])), + "raw_tokens": data["context_tokens"], + "input_tokens": input_tokens, + "input_ids": input_ids, + "mask": mask, + "offset": offset, + "available_spans": available_spans, + "labels": None, + } + final_data.append(ins) + + return final_data + + def predict_transform(self, data: list[dict]): + """ + Args: + data: a list of dict with query, context, and background strings + """ + dataset = [] + for idx, ins in enumerate(data): + idx = f"Prediction#{idx}" + dataset.append( + { + "id": idx, + "query_tokens": list(ins["query"]), + "context_tokens": list(ins["context"]), + "background_tokens": list(ins.get("background")), + "answer_index": [], + } + ) + final_data = self(dataset, disable_pbar=True, 
num_samples=0) + return final_data + + +class CachedLabelPointerTransform(CachedTransformOneBase): + """Transform for label-token linking for skip consecutive spans""" + + def __init__( + self, + max_seq_len: int, + plm_dir: str, + mode: str = "w2", + label_span: str = "tag", + include_instructions: bool = True, + **kwargs, + ) -> None: + super().__init__() + + self.max_seq_len: int = max_seq_len + self.mode = mode + self.label_span = label_span + self.include_instructions = include_instructions + + self.tokenizer: DebertaV2TokenizerFast = DebertaV2TokenizerFast.from_pretrained( + plm_dir + ) + self.lc_token = "[LC]" + self.lm_token = "[LM]" + self.lr_token = "[LR]" + self.i_token = "[I]" + self.tl_token = "[TL]" + self.tp_token = "[TP]" + self.b_token = "[B]" + num_added = self.tokenizer.add_tokens( + [ + self.lc_token, + self.lm_token, + self.lr_token, + self.i_token, + self.tl_token, + self.tp_token, + self.b_token, + ] + ) + assert num_added == 7 + + self.collate_fn: GeneralCollateFn = GeneralCollateFn( + { + "input_ids": torch.long, + "mask": torch.long, + "labels": torch.long, + "spans": None, + }, + guessing=False, + missing_key_as_null=True, + # only for pre-training + discard_missing=False, + ) + + self.collate_fn.update_before_tensorify = self.skip_consecutive_span_labels + + def transform(self, instance: dict, **kwargs): + # input + tokens = [self.tokenizer.cls_token] + mask = [1] + label_map = {"lc": {}, "lm": {}, "lr": {}} + # (2, 3): {"type": "lc", "task": "cls/ent/rel/event/hyper_rel/discontinuous_ent", "string": ""} + span_to_label = {} + + def _update_seq( + label: str, + label_type: str, + task: str = "", + label_mask: int = 4, + content_mask: int = 5, + ): + if label not in label_map[label_type]: + label_token_map = { + "lc": self.lc_token, + "lm": self.lm_token, + "lr": self.lr_token, + } + label_tag_start_idx = len(tokens) + tokens.append(label_token_map[label_type]) + mask.append(label_mask) + label_tag_end_idx = len(tokens) - 1 # exact end position + label_tokens = self.tokenizer(label, add_special_tokens=False).tokens() + label_content_start_idx = len(tokens) + tokens.extend(label_tokens) + mask.extend([content_mask] * len(label_tokens)) + label_content_end_idx = len(tokens) - 1 # exact end position + + if self.label_span == "tag": + start_idx = label_tag_start_idx + end_idx = label_tag_end_idx + elif self.label_span == "content": + start_idx = label_content_start_idx + end_idx = label_content_end_idx + else: + raise ValueError(f"label_span={self.label_span} is not supported") + + if end_idx == start_idx: + label_map[label_type][label] = (start_idx,) + else: + label_map[label_type][label] = (start_idx, end_idx) + span_to_label[label_map[label_type][label]] = { + "type": label_type, + "task": task, + "string": label, + } + return label_map[label_type][label] + + if self.include_instructions: + instruction = instance.get("instruction") + if not instruction: + logger.warning( + "include_instructions=True, while the instruction is empty!" 
+ ) + else: + instruction = "" + if instruction: + tokens.append(self.i_token) + mask.append(2) + instruction_tokens = self.tokenizer( + instruction, add_special_tokens=False + ).tokens() + tokens.extend(instruction_tokens) + mask.extend([3] * len(instruction_tokens)) + types = instance["schema"].get("cls") + if types: + for t in types: + _update_seq(t, "lc", task="cls") + mention_types = instance["schema"].get("ent") + if mention_types: + for mt in mention_types: + _update_seq(mt, "lm", task="ent") + discon_ent_types = instance["schema"].get("discontinuous_ent") + if discon_ent_types: + for mt in discon_ent_types: + _update_seq(mt, "lm", task="discontinuous_ent") + rel_types = instance["schema"].get("rel") + if rel_types: + for rt in rel_types: + _update_seq(rt, "lr", task="rel") + hyper_rel_schema = instance["schema"].get("hyper_rel") + if hyper_rel_schema: + for rel, qualifiers in hyper_rel_schema.items(): + _update_seq(rel, "lr", task="hyper_rel") + for qualifier in qualifiers: + _update_seq(qualifier, "lr", task="hyper_rel") + event_schema = instance["schema"].get("event") + if event_schema: + for event_type, roles in event_schema.items(): + _update_seq(event_type, "lm", task="event") + for role in roles: + _update_seq(role, "lr", task="event") + + text = instance.get("text") + if text: + text_tokenized = self.tokenizer( + text, return_offsets_mapping=True, add_special_tokens=False + ) + if any(val for val in label_map.values()): + text_label_token = self.tl_token + else: + text_label_token = self.tp_token + tokens.append(text_label_token) + mask.append(6) + remain_token_len = self.max_seq_len - 1 - len(tokens) + if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train": + return None + text_off = len(tokens) + text_tokens = text_tokenized.tokens()[:remain_token_len] + tokens.extend(text_tokens) + mask.extend([7] * len(text_tokens)) + else: + text_tokenized = None + + bg = instance.get("bg") + if bg: + bg_tokenized = self.tokenizer( + bg, return_offsets_mapping=True, add_special_tokens=False + ) + tokens.append(self.b_token) + mask.append(8) + remain_token_len = self.max_seq_len - 1 - len(tokens) + if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train": + return None + bg_tokens = bg_tokenized.tokens()[:remain_token_len] + tokens.extend(bg_tokens) + mask.extend([9] * len(bg_tokens)) + else: + bg_tokenized = None + + tokens.append(self.tokenizer.sep_token) + mask.append(10) + + # labels + # spans: [[(ent_type start, ent_type end + 1), (ent s, ent e + 1)]] + spans = [] # one span may have many parts + if "cls" in instance["ans"]: + for t in instance["ans"]["cls"]: + part = label_map["lc"][t] + spans.append([part]) + if "ent" in instance["ans"]: + for ent in instance["ans"]["ent"]: + label_part = label_map["lm"][ent["type"]] + position_seq = self.char_to_token_span( + ent["span"], text_tokenized, text_off + ) + spans.append([label_part, position_seq]) + if "discontinuous_ent" in instance["ans"]: + for ent in instance["ans"]["discontinuous_ent"]: + label_part = label_map["lm"][ent["type"]] + ent_span = [label_part] + for part in ent["span"]: + position_seq = self.char_to_token_span( + part, text_tokenized, text_off + ) + ent_span.append(position_seq) + spans.append(ent_span) + if "rel" in instance["ans"]: + for rel in instance["ans"]["rel"]: + label_part = label_map["lr"][rel["relation"]] + head_position_seq = self.char_to_token_span( + rel["head"]["span"], text_tokenized, text_off + ) + tail_position_seq = self.char_to_token_span( + 
rel["tail"]["span"], text_tokenized, text_off + ) + spans.append([label_part, head_position_seq, tail_position_seq]) + if "hyper_rel" in instance["ans"]: + for rel in instance["ans"]["hyper_rel"]: + label_part = label_map["lr"][rel["relation"]] + head_position_seq = self.char_to_token_span( + rel["head"]["span"], text_tokenized, text_off + ) + tail_position_seq = self.char_to_token_span( + rel["tail"]["span"], text_tokenized, text_off + ) + # rel_span = [label_part, head_position_seq, tail_position_seq] + for q in rel["qualifiers"]: + q_label_part = label_map["lr"][q["label"]] + q_position_seq = self.char_to_token_span( + q["span"], text_tokenized, text_off + ) + spans.append( + [ + label_part, + head_position_seq, + tail_position_seq, + q_label_part, + q_position_seq, + ] + ) + if "event" in instance["ans"]: + for event in instance["ans"]["event"]: + event_type_label_part = label_map["lm"][event["event_type"]] + trigger_position_seq = self.char_to_token_span( + event["trigger"]["span"], text_tokenized, text_off + ) + trigger_part = [event_type_label_part, trigger_position_seq] + spans.append(trigger_part) + for arg in event["args"]: + role_label_part = label_map["lr"][arg["role"]] + arg_position_seq = self.char_to_token_span( + arg["span"], text_tokenized, text_off + ) + arg_part = [role_label_part, trigger_position_seq, arg_position_seq] + spans.append(arg_part) + if "span" in instance["ans"]: + # Extractive-QA or Extractive-MRC tasks + for span in instance["ans"]["span"]: + span_position_seq = self.char_to_token_span( + span["span"], text_tokenized, text_off + ) + spans.append([span_position_seq]) + + if self.mode == "w2": + new_spans = [] + for parts in spans: + new_parts = [] + for part in parts: + new_parts.append(tuple(range(part[0], part[-1] + 1))) + new_spans.append(new_parts) + spans = new_spans + elif self.mode == "span": + spans = spans + else: + raise ValueError(f"mode={self.mode} is not supported") + + ins = { + "raw": instance, + "tokens": tokens, + "input_ids": self.tokenizer.convert_tokens_to_ids(tokens), + "mask": mask, + "spans": spans, + "label_map": label_map, + "span_to_label": span_to_label, + "labels": None, # labels are calculated dynamically in collate_fn + } + return ins + + def char_to_token_span( + self, span: list[int], tokenized: BatchEncoding, offset: int = 0 + ) -> list[int]: + token_s = tokenized.char_to_token(span[0]) + token_e = tokenized.char_to_token(span[1] - 1) + if token_e == token_s: + position_seq = (offset + token_s,) + else: + position_seq = (offset + token_s, offset + token_e) + return position_seq + + def skip_consecutive_span_labels(self, data: dict) -> dict: + bs = len(data["input_ids"]) + max_seq_len = max(len(input_ids) for input_ids in data["input_ids"]) + batch_seq_len = min(self.max_seq_len, max_seq_len) + for i in range(bs): + data["input_ids"][i] = data["input_ids"][i][:batch_seq_len] + data["mask"][i] = data["mask"][i][:batch_seq_len] + assert len(data["input_ids"][i]) == len(data["mask"][i]) + pad_len = batch_seq_len - len(data["mask"][i]) + data["input_ids"][i] = ( + data["input_ids"][i] + [self.tokenizer.pad_token_id] * pad_len + ) + data["mask"][i] = data["mask"][i] + [0] * pad_len + data["labels"][i] = encode_nnw_nsw_thw_mat(data["spans"][i], batch_seq_len) + + # # for debugging only + # pred_spans = decode_nnw_nsw_thw_mat(data["labels"][i].unsqueeze(0))[0] + # sorted_gold = sorted(set(tuple(x) for x in data["spans"][i])) + # sorted_pred = sorted(set(tuple(x) for x in pred_spans)) + # if sorted_gold != sorted_pred: + # 
breakpoint() + + # # for pre-training only + # del data["spans"] + + return data diff --git a/src/udi/__init__.py b/src/udi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..535ca0ac8e86273765bceec6ac221c3bb9b39f0f --- /dev/null +++ b/src/udi/__init__.py @@ -0,0 +1,38 @@ +# udi-v1: universal data interface +{ + "id": "semeval.train.0", + "instruction": "instruction text", + "schema": { + "cls": ["class1", "class2"], + "ent": ["person", "location"], + "rel": ["birth in", "study in"], + "event": { + "event type (attack)": ["roles like instrument", "attacker"], + "another type": ["role", "role"], + }, + }, + "ans": { + "cls": ["class1"], + "ent": [ + {"type": "person", "text": "1234", "span": [0, 4]} + ], # span: [start, end + 1] + "rel": [ + { + "relation": "study in", + "head": {"text": "1234", "span": [0, 4]}, + "tail": {"text": "1234", "span": [5, 9]}, + } + ], + "event": [ + { + "event_type": "attack", + "trigger": {"text": "hit", "span": [6, 9]}, + "args": [{"role": "instrument", "text": "ax", "span": [8, 10]}], + } + ], + "span": [{"text": "machine learning", "span": [16, 32]}], + }, + # DONE: whether or not to concatenate instruction with text (v2) + "text": "plain text", + "bg": "background text", +} diff --git a/src/udi/check.py b/src/udi/check.py new file mode 100644 index 0000000000000000000000000000000000000000..07a15e03c1452e66a743abeb495eb4e1ea956afa --- /dev/null +++ b/src/udi/check.py @@ -0,0 +1,139 @@ +from rex.utils.io import load_jsonlines + + +def check_udi_instance(instance: dict): + assert isinstance(instance["id"], str) + assert isinstance(instance["instruction"], str) + assert isinstance(instance["schema"], dict) + for key in instance["schema"]: + assert key in ["cls", "ent", "rel", "event"] + if key in ["cls", "ent", "rel"]: + assert isinstance(instance["schema"][key], list) and all( + isinstance(x, str) for x in instance["schema"][key] + ) + elif key == "event": + assert isinstance(instance["schema"][key], dict) + for event_type in instance["schema"][key]: + assert isinstance(instance["schema"][key][event_type], list) and all( + isinstance(x, str) for x in instance["schema"][key][event_type] + ) + else: + raise ValueError + assert isinstance(instance["ans"], dict) + for key in instance["ans"]: + assert key in ["cls", "ent", "rel", "event", "span"] + if key == "cls": + assert isinstance(instance["ans"][key], list) and all( + isinstance(x, str) for x in instance["ans"][key] + ) + elif key == "ent": + assert isinstance(instance["ans"][key], list) and all( + isinstance(x, dict) for x in instance["ans"][key] + ) + for ent in instance["ans"][key]: + assert ( + isinstance(ent["type"], str) + and ent["type"] in instance["schema"]["ent"] + ) + assert ( + isinstance(ent["text"], str) + and instance["text"][ent["span"][0] : ent["span"][1]] == ent["text"] + ) + assert ( + isinstance(ent["span"], list) + and len(ent["span"]) == 2 + and all(isinstance(x, int) for x in ent["span"]) + ) + elif key == "rel": + assert isinstance(instance["ans"][key], list) and all( + isinstance(x, dict) for x in instance["ans"][key] + ) + for rel in instance["ans"][key]: + assert ( + isinstance(rel["relation"], str) + and rel["relation"] in instance["schema"]["rel"] + ) + assert ( + isinstance(rel["head"], dict) + and instance["text"][ + rel["head"]["span"][0] : rel["head"]["span"][1] + ] + == rel["head"]["text"] + ) + assert ( + isinstance(rel["tail"], dict) + and instance["text"][ + rel["tail"]["span"][0] : rel["tail"]["span"][1] + ] + == rel["tail"]["text"] + ) + 
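+        # an event answer must name an event type from the schema, its trigger
+        # text must equal the corresponding slice of `text`, and every argument
+        # role must be declared for that event type in the schema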
elif key == "event": + assert isinstance(instance["ans"][key], list) and all( + isinstance(x, dict) for x in instance["ans"][key] + ) + for event in instance["ans"][key]: + assert event["event_type"] in instance["schema"]["event"] + assert ( + isinstance(event["trigger"], dict) + and event["trigger"]["text"] in instance["text"] + and instance["text"][ + event["trigger"]["span"][0] : event["trigger"]["span"][1] + ] + == event["trigger"]["text"] + ) + for arg in event["args"]: + assert ( + arg["role"] in instance["schema"]["event"][event["event_type"]] + ) + assert ( + isinstance(arg["text"], str) + and instance["text"][arg["span"][0] : arg["span"][1]] + == arg["text"] + ) + elif key == "span": + assert isinstance(instance["ans"][key], list) and all( + isinstance(x, dict) for x in instance["ans"][key] + ) + for span in instance["ans"][key]: + assert ( + isinstance(span["text"], str) + and instance["text"][span["span"][0] : span["span"][1]] + == span["text"] + ) + else: + raise ValueError + assert isinstance(instance["text"], str) + assert isinstance(instance["bg"], str) + for key in ["ent", "rel", "event"]: + if instance["schema"].get(key): + assert len(instance["text"]) > 0 + if "span" in instance["ans"]: + assert len(instance["text"]) > 0 + assert instance["instruction"] or instance["text"] or instance["bg"] + + +def is_valid_udi_instance(instance: dict): + ok = True + try: + check_udi_instance(instance) + except: + ok = False + return ok + + +def main(): + filepaths = [] + for filepath in filepaths: + data = load_jsonlines(filepath) + data_ok = True + for ins in data: + ok = is_valid_udi_instance(ins) + if not ok: + data_ok = False + break + if not data_ok: + print(filepath) + + +if __name__ == "__main__": + main() diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f5516fdea1457db640645847db0a842dfda398 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,320 @@ +from collections import defaultdict + +import torch +from rex.utils.iteration import windowed_queue_iter +from rex.utils.position import find_all_positions + + +def find_paths_from_adj_mat(adj_mat: torch.Tensor) -> list[tuple[int]]: + assert adj_mat.shape[0] == adj_mat.shape[1] and len(adj_mat.shape) == 2 + + paths = [] + self_loops = set() + adj_map = defaultdict(set) + rev_adj_map = defaultdict(set) + # current -> next + for c, n in adj_mat.detach().nonzero().tolist(): + # self-loop + if c == n: + self_loops.add(c) + else: + adj_map[c].add(n) + # reversed map + rev_adj_map[n].add(c) + for self_loop_node in self_loops: + paths.append((self_loop_node,)) + + def track(path: tuple[int], c: int): + visited: set[tuple[int]] = set() + stack = [(path, c)] + while stack: + path, c = stack.pop() + if c in adj_map: + for n in adj_map[c]: + if (c, n) in visited: + continue + visited.add((c, n)) + stack.append((path + (c,), n)) + # else: + if path: + paths.append(path + (c,)) + + # def track(path: tuple[int], c: int, visited: set[tuple[int]]): + # if c in adj_map: + # for n in adj_map[c]: + # if (c, n) in visited: + # continue + # visited.add((c, n)) + # track(path + (c,), n, visited) + # else: + # if path: + # paths.append(path + (c,)) + + # # # include loops + # # if path not in paths and all(not set(path).issubset(p) for p in paths): + # # paths.append(path) + + start_nodes = set(adj_map.keys()) - set(rev_adj_map.keys()) + for c in start_nodes: + ns = adj_map[c] + for n in ns: + track((c,), n) + + return paths + + +def encode_nnw_thw_mat( + spans: list[tuple[int]], seq_len: int, nnw_id: 
int = 0, thw_id: int = 1
+) -> torch.Tensor:
+    mat = torch.zeros(2, seq_len, seq_len)
+    for span in spans:
+        if len(span) == 1:
+            mat[:, span[0], span[0]] = 1
+        else:
+            for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
+                mat[nnw_id, s, e] = 1
+            mat[thw_id, span[-1], span[0]] = 1
+    return mat
+
+
+def decode_nnw_thw_mat(
+    batch_mat: torch.LongTensor,
+    nnw_id: int = 0,
+    thw_id: int = 1,
+    offsets: list[int] = None,
+) -> list[list[tuple[int]]]:
+    """Decode NNW THW matrix into a list of spans
+
+    Args:
+        batch_mat: (batch_size, 2, seq_len, seq_len)
+    """
+    ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
+    assert seq_len1 == seq_len2
+    assert cls_num == 2
+
+    result_batch = []
+    for ins_id in range(ins_num):
+        offset = offsets[ins_id] if offsets else 0
+        ins_span_paths = []
+        # ins_mat: (2, seq_len, seq_len)
+        ins_mat = batch_mat[ins_id]
+        nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
+        end_start_to_paths = defaultdict(set)
+        for path in nnw_paths:
+            end_start_to_paths[(path[-1], path[0])].add(path)
+        thw_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()
+        # reversed match, end -> start
+        for e, s in thw_pairs:
+            for path in end_start_to_paths[(e, s)]:
+                ins_span_paths.append(tuple(i - offset for i in path))
+        result_batch.append(ins_span_paths)
+
+    return result_batch
+
+
+def decode_pointer_mat(
+    batch_mat: torch.LongTensor, offsets: list[int] = None
+) -> list[list[tuple[int]]]:
+    batch_paths = []
+    for i in range(len(batch_mat)):
+        offset = offsets[i] if offsets else 0
+        coordinates = (batch_mat[i, 0] == 1).nonzero().tolist()
+        paths = []
+        for s, e in coordinates:
+            path = tuple(range(s - offset, e + 1 - offset))
+            paths.append(path)
+        batch_paths.append(paths)
+    return batch_paths
+
+
+def encode_nnw_nsw_thw_mat(
+    spans: list[list[tuple[int]]],
+    seq_len: int,
+    nnw_id: int = 0,
+    nsw_id: int = 1,
+    thw_id: int = 2,
+) -> torch.Tensor:
+    mat = torch.zeros(3, seq_len, seq_len)
+    for parts in spans:
+        span = ()
+        for p_i, part in enumerate(parts):
+            if not all(0 <= el <= seq_len - 1 for el in part):
+                continue
+            span += part
+            if p_i < len(parts) - 1 and 0 <= parts[p_i + 1][0] <= seq_len - 1:
+                # current part to next part
+                mat[nsw_id, parts[p_i][-1], parts[p_i + 1][0]] = 1
+        if len(span) == 1:
+            mat[:, span[0], span[0]] = 1
+        elif len(span) > 1:
+            for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
+                mat[nnw_id, s, e] = 1
+        if span:
+            mat[thw_id, span[-1], span[0]] = 1
+    return mat
+
+
+def split_tuple_by_positions(nums, positions) -> list:
+    """
+    Examples:
+        >>> nums = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+        >>> positions = [2, 5, 7]
+        >>> split_tuple_by_positions(nums, positions)
+        [(1, 2), (3, 4, 5), (6, 7), (8, 9, 10)]
+    """
+    # Check if the given positions are valid
+    if not all(p < len(nums) for p in positions):
+        raise ValueError("Invalid positions")
+
+    # Add 0 and len(nums) to the list of positions
+    positions = [0] + sorted(positions) + [len(nums)]
+
+    # Split the tuple into multiple tuples based on the positions
+    result = []
+    for i in range(1, len(positions)):
+        start = positions[i - 1]
+        end = positions[i]
+        result.append(nums[start:end])
+
+    return result
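+
+
+# Illustrative round trip (example values are mine, not from this repo):
+# encoding the discontinuous span [[(2, 3), (7,)]] with seq_len = 10 sets
+#   NNW[2, 3] = NNW[3, 7] = 1   (next-neighbor links along the flattened span)
+#   NSW[3, 7] = 1               (skip link marking the jump between the parts)
+#   THW[7, 2] = 1               (tail-to-head link closing the whole span)
+# and decoding recovers the original part structure:
+#
+#   mat = encode_nnw_nsw_thw_mat([[(2, 3), (7,)]], 10)
+#   decode_nnw_nsw_thw_mat(mat.unsqueeze(0))  # -> [[((2, 3), (7,))]]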
+
+
+def decode_nnw_nsw_thw_mat(
+    batch_mat: torch.LongTensor,
+    nnw_id: int = 0,
+    nsw_id: int = 1,
+    thw_id: int = 2,
+    offsets: list[int] = None,
+) -> list[list[tuple[int]]]:
+    """Decode NNW NSW THW matrix into a list of spans
+    One span has multiple parts
+
+    Args:
+        batch_mat: (batch_size, 3, seq_len, seq_len)
+    """
+    ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
+    assert seq_len1 == seq_len2
+    assert cls_num == 3
+
+    result_batch = []
+    for ins_id in range(ins_num):
+        offset = offsets[ins_id] if offsets else 0
+        ins_span_paths = set()
+        # ins_mat: (3, seq_len, seq_len)
+        ins_mat = batch_mat[ins_id]
+        nsw_connections = {
+            (part1e, part2s)
+            for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()
+        }
+        nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])
+        end_start_to_paths = defaultdict(set)
+        for path in nnw_paths:
+            end_start_to_paths[(path[-1], path[0])].add(path)
+        thw_pairs = ins_mat[thw_id, ...].detach().nonzero().tolist()
+        # reversed match, end -> start
+        for e, s in thw_pairs:
+            for path in nnw_paths:
+                if s in path:
+                    sub_path = path[path.index(s) :]
+                    if e in sub_path:
+                        sub_path = sub_path[: sub_path.index(e) + 1]
+                        chain = tuple(i - offset for i in sub_path)
+                        parts = []
+                        all_sep_positions = set()
+                        # cut path into multiple spans if there are skip links
+                        if len(chain) > 1:
+                            for sep in nsw_connections:
+                                sep = tuple(i - offset for i in sep)
+                                positions = find_all_positions(list(chain), list(sep))
+                                if positions:
+                                    # +1: (5, 6, 269) with (6, 269) as sep, found position is 1,
+                                    # while we want to split after 6, which needs +1
+                                    positions = {p[0] + 1 for p in positions}
+                                    all_sep_positions.update(positions)
+                            parts = split_tuple_by_positions(chain, all_sep_positions)
+                        if not parts:
+                            parts = [chain]
+                        ins_span_paths.add(tuple(parts))
+        result_batch.append(list(ins_span_paths))
+
+    return result_batch
+
+
+# def encode_nnw_nsw_thw_mat(
+#     spans: list[list[tuple[int]]],
+#     seq_len: int,
+#     nnw_id: int = 0,
+#     nsw_id: int = 1,
+#     thw_id: int = 2,
+# ) -> torch.Tensor:
+#     mat = torch.zeros(3, seq_len, seq_len)
+#     for span in spans:
+#         for p_i, part in enumerate(span):
+#             if len(part) == 1:
+#                 mat[:, part[0], part[0]] = 1
+#             else:
+#                 for s, e in windowed_queue_iter(part, 2, 1, drop_last=True):
+#                     mat[nnw_id, s, e] = 1
+#             if p_i < len(span) - 1:
+#                 # current part to next part
+#                 mat[nsw_id, span[p_i][-1], span[p_i + 1][0]] = 1
+#         mat[thw_id, span[-1][-1], span[0][0]] = 1
+#     return mat
+
+
+# def decode_nnw_nsw_thw_mat(
+#     batch_mat: torch.LongTensor,
+#     nnw_id: int = 0,
+#     nsw_id: int = 1,
+#     thw_id: int = 2,
+#     offsets: list[int] = None,
+# ) -> list[list[tuple[int]]]:
+#     """Decode NNW NSW THW matrix into a list of spans
+#     One span has multiple parts
+
+#     Args:
+#         batch_mat: (batch_size, 3, seq_len, seq_len)
+#     """
+
+#     ins_num, cls_num, seq_len1, seq_len2 = batch_mat.shape
+#     assert seq_len1 == seq_len2
+#     assert cls_num == 2
+
+#     result_batch = []
+#     for ins_id in range(ins_num):
+#         offset = offsets[ins_id] if offsets else 0
+#         ins_span_paths = []
+#         # ins_mat: (3, seq_len, seq_len)
+#         ins_mat = batch_mat[ins_id]
+#         nnw_paths = find_paths_from_adj_mat(ins_mat[nnw_id, ...])

+#         path_index = {"s": defaultdict(set), "e": defaultdict(set)}
+#         for path in nnw_paths:
+#             s = path[0]
+#             e = path[-1]
+#             path_index["s"][s].add(path)
+#             path_index["e"][e].add(path)

+#         nsw_connections = {(part1e, part2s) for part1e, part2s in ins_mat[nsw_id, ...].detach().nonzero().tolist()}
+#         thw_connections = {(span_e, span_s) for span_e, span_s in ins_mat[thw_id, ...].detach().nonzero().tolist()}
+#         for e, s in thw_connections:

+#         path_span_combinations = []
+#         for part1_e, part2_s in nsw_connections:
+#             part1s = path_index["e"][part1_e]
+#             part2s = path_index["s"][part2_s]
+#             # for part1 in part1s:
+#             #     for part2 in part2s:
+#             #         if ()

+#         end_start_to_paths = defaultdict(set)
+#         for path in nnw_paths:
+#             end_start_to_paths[(path[-1], path[0])].add(path)
+#         # reversed match, end -> start
+#         for e, s in thw_pairs:
+#             for path in end_start_to_paths[(e, s)]:
+#                 ins_span_paths.append(tuple(i - offset for i in path))
+#         result_batch.append(ins_span_paths)

+#     return result_batch
diff --git a/src/wait.py b/src/wait.py
new file mode 100644
index 0000000000000000000000000000000000000000..adb992ce66bb549841b24b0492b73ccbf41ab32d
--- /dev/null
+++ b/src/wait.py
@@ -0,0 +1,47 @@
+import argparse
+import random
+import string
+import sys
+
+from watchmen import WatchClient
+
+
+def parse_args(in_args=None):
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--task_name", type=str, required=True, help="Task Name")
+    arg_parser.add_argument("--cuda", type=str, required=True, help="cuda to be waited")
+    arg_parser.add_argument(
+        "--req_gpu_num",
+        type=int,
+        required=False,
+        default=1,
+        help="request number of gpus",
+    )
+    arg_parser.add_argument(
+        "--wait",
+        choices=["schedule", "queue", "none"],
+        default="none",
+        help="scheduling/queue wait",
+    )
+    arg_info = arg_parser.parse_args(args=in_args)
+    return arg_info
+
+
+if __name__ == "__main__":
+    in_argv = parse_args()
+    if in_argv.wait == "none":
+        sys.exit(0)
+    random_id = "-" + "".join(random.sample(string.ascii_letters + string.digits, 8))
+    exp_id = in_argv.task_name + random_id
+    watch_client = WatchClient(
+        id=exp_id,
+        # parse the comma-separated GPU list (e.g. "0,1") without eval()
+        gpus=[int(x) for x in in_argv.cuda.split(",")],
+        server_host="localhost",
+        server_port=62333,
+        req_gpu_num=in_argv.req_gpu_num,
+        mode=in_argv.wait,
+        timeout=60,
+    )
+    available_gpus = watch_client.wait()
+    available_gpus = [str(x) for x in available_gpus]
+    print(",".join(available_gpus))
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000000000000000000000000000000000000..b8d0f8abf145d51bc3d27a024717321e1ece6bbc
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,12 @@
+[flake8]
+ignore=
+    # line length
+    E501,
+    # whitespace before ':'
+    E203,
+    # line break before binary operator
+    W503,
+    # import but not used
+    F401
+exclude=
+    debug.py