diff --git a/.idea/SummerTime.iml b/.idea/SummerTime.iml new file mode 100644 index 0000000000000000000000000000000000000000..d0876a78d06ac03b5d78c8dcdb95570281c6f1d6 --- /dev/null +++ b/.idea/SummerTime.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000000000000000000000000000000000000..3fcc36dd4ca32e866b8b116d958c417702790305 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,16 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..7e79ffeeee02ce384410d0e7a1e1a7799fdc46c7 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..37af4f234f3acdeb1851e68efc3f1f017e455116 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + sourc e, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 SummerTime + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/README.md b/README.md index 69eca38503ab7e777d9feeb6410498196376036f..fc6323ee09de6db755ad7bee0bd3564aec16cdc2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ --- title: SummerTime -emoji: 💩 +emoji: 🔥 colorFrom: purple colorTo: green sdk: gradio diff --git a/SummerTime.egg-info/PKG-INFO b/SummerTime.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..5534ef587a94d02bd82a4f8e744b1c3218aa0d0c --- /dev/null +++ b/SummerTime.egg-info/PKG-INFO @@ -0,0 +1,124 @@ +Metadata-Version: 2.1 +Name: SummerTime +Version: 0.1 +Summary: A summarization mode +Home-page: https://github.com/LILYlab +Author: Ansong Ni, Murori Mutuma, Zhangir Azerbayev, Yusen Zhang, Tao Yu, Dragomir Radev +Author-email: ansong.ni@yale.edu, murorimutuma@gmail.com, zhangir.azerbayev@yale.edu +License: UNKNOWN +Description: # SummerTime + + A library to help users choose appropriate summarization tools based on their specific tasks or needs. Includes models, evaluation metrics, and datasets. 
+ + + + ## Installation and setup + + #### Create and activate a new `conda` environment: + ```bash + conda create -n st python=3.7 + conda activate st + ``` + + #### `pip` dependencies for local demo: + ```bash + pip install -r requirements.txt + ``` + + + + ## Quick Start + Imports the model module, initializes the default model, and summarizes sample documents. + ```python + import model as st_model + + model = st_model.summarizer() + documents = [ + """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. + The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected + by the shutoffs which were expected to last through at least midday tomorrow.""" + ] + model.summarize(documents) + + # ["California's largest electricity provider has turned off power to hundreds of thousands of customers."] + ``` + + Also, please see the `demo.ipynb` Jupyter notebook for more examples. To start the demo notebook on localhost: + ```bash + jupyter notebook demo.ipynb + ``` + + + + ## Models + Import and initialization: + ```python + import model as st_model + + default_model = st_model.summarizer() + bart_model = st_model.bart_model.BartModel() + pegasus_model = st_model.pegasus_model.PegasusModel() + lexrank_model = st_model.lexrank_model.LexRankModel() + textrank_model = st_model.textrank_model.TextRankModel() + ``` + + All models can be initialized with the following optional arguments: + ```python + def __init__(self, + trained_domain: str=None, + max_input_length: int=None, + max_output_length: int=None, + ): + ``` + + All models implement the following methods: + ```python + def summarize(self, + corpus: Union[List[str], List[List[str]]], + queries: List[str]=None) -> List[str]: + + def show_capability(cls) -> None: + + def generate_basic_description(cls) -> str: + ``` + + + + ## Evaluation + Import and initialization: + ```python + import eval as st_eval + + bert_eval = st_eval.bertscore() + bleu_eval = st_eval.bleu_eval() + rouge_eval = st_eval.rouge() + rougewe_eval = st_eval.rougewe() + ``` + + All evaluation metrics can be initialized with the following optional arguments: + ```python + def __init__(self, metric_name): + ``` + + All evaluation metric objects implement the following methods: + ```python + def evaluate(self, model, data): + + def get_dict(self, keys): + ``` + + + ## Datasets + Import and initialization: + ```python + import dataset.stdatasets as st_data + ``` + + ## Contributors + This repository is built by the [LILY Lab](https://yale-lily.github.io/) at Yale University, led by Prof. [Dragomir Radev](https://cpsc.yale.edu/people/dragomir-radev). The main contributors are [Ansong Ni](https://niansong1996.github.io), Zhangir Azerbayev, Troy Feng, Murori Mutuma and Yusen Zhang (Penn State). For comments and questions, please open an issue.
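+
+ ## End-to-end example
+ A rough end-to-end sketch tying datasets, models, and instances together. The module paths and the `summarize` call follow the sections above, and the `train_set` / `source` / `summary` attributes follow `dataset/st_dataset.py` added in this change; the choice of `CnndmDataset` is only illustrative and may need adjusting.
+ ```python
+ import itertools
+
+ import model as st_model
+ import dataset.stdatasets as st_data
+
+ # Load a summarization dataset (CNN/DailyMail is used here purely as an example)
+ dataset = st_data.CnndmDataset()
+
+ # Take a few SummInstance objects; each carries a source document,
+ # a reference summary, and (for query-based datasets) a query
+ instances = list(itertools.islice(dataset.train_set, 3))
+ documents = [ins.source for ins in instances]
+
+ # Summarize with the default model and compare against the references
+ model = st_model.summarizer()
+ generated = model.summarize(documents)
+
+ for ins, summ in zip(instances, generated):
+     print("Reference:", ins.summary)
+     print("Generated:", summ)
+ ```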
+ +Platform: UNKNOWN +Classifier: Programming Language :: Python :: 3 +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Description-Content-Type: text/markdown diff --git a/SummerTime.egg-info/SOURCES.txt b/SummerTime.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..b8fa59856c06af4a865bdde6b62bc49d51806867 --- /dev/null +++ b/SummerTime.egg-info/SOURCES.txt @@ -0,0 +1,46 @@ +README.md +setup.py +summertime.py +SummerTime.egg-info/PKG-INFO +SummerTime.egg-info/SOURCES.txt +SummerTime.egg-info/dependency_links.txt +SummerTime.egg-info/top_level.txt +dataset/__init__.py +dataset/datasets_demo.py +dataset/huggingface_datasets.py +dataset/non_huggingface_datasets.py +dataset/st_dataset.py +evaluation/__init__.py +evaluation/base_metric.py +evaluation/bertscore_metric.py +evaluation/bleu_metric.py +evaluation/meteor_metric.py +evaluation/rouge_metric.py +evaluation/rougewe_metric.py +evaluation/summeval_metric.py +model/__init__.py +model/base_model.py +model/defaults.py +model/dialogue/__init__.py +model/dialogue/hmnet_model.py +model/multi_doc/__init__.py +model/multi_doc/base_multi_doc_model.py +model/multi_doc/multi_doc_joint_model.py +model/multi_doc/multi_doc_separate_model.py +model/query_based/__init__.py +model/query_based/base_query_based_model.py +model/query_based/bm25_model.py +model /query_based/tf_idf_model.py +model/single_doc/__init__.py +model/single_doc/bart_model.py +model/single_doc/base_single_doc_model.py +model/single_doc/lexrank_model.py +model/single_doc/longformer_model.py +model/single_doc/pegasus_model.py +model/single_doc/textrank_model.py +tests/__init__.py +tests/dataset_test.py +tests/demo_test.py +tests/evaluation_test.py +tests/integration_test.py +tests/model_test.py \ No newline at end of file diff --git a/SummerTime.egg-info/dependency_links.txt b/SummerTime.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/SummerTime.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/SummerTime.egg-info/top_level.txt b/SummerTime.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..d841e5baccc91f161341cf4a7571bd1f5b62f404 --- /dev/null +++ b/SummerTime.egg-info/top_level.txt @@ -0,0 +1,4 @@ +dataset +evaluation +model +tests diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4686f47d52f87f9561f7a9182a2e91d8cb1e0d --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +import SummerTime.model +import SummerTime.dataset.st_dataset as data +import SummerTime.evaluation diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7d678129a2bcdbfe49452d42d6e3baba3d78c298 --- /dev/null +++ b/app.py @@ -0,0 +1,28 @@ +import os +import model as st_model +import gradio as gr + + +model = st_model.summarizer() + +def inference(text): + documents = [text] + model.summarize(documents) + return model.summarize(documents)[0] + +title = "SummerTime: Text Summarization for Non-Experts" +description = "This is a demo of SummerTime: An open-source text summarization toolkit for non-experts. You can read more about the project at the links below. Input your text below (or click one of the examples to load them), and the model will generate a summary for it." +article = "
SummerTime: Text Summarization Toolkit for Non-experts | Github Repo | Colab Notebook
" + +gr.Interface( + inference, + [gr.inputs.Textbox(label="Input", lines=20)], + gr.outputs.Textbox(label="Output"), + title=title, + description=description, + article=article, + examples=[["""PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. + The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected + by the shutoffs which were expected to last through at least midday tomorrow."""], + ["""Representative Kevin McCarthy, the House Republican leader, has threatened to retaliate against any company that complies with the congressional committee investigating the Jan. 6 riot, after the panel asked dozens of firms to preserve the phone and social media records of 11 far-right members of Congress who pushed to overturn the results of the 2020 election. Mr. McCarthy’s warning was an escalation of his efforts to thwart a full accounting of the deadly attack at the Capitol carried out by a pro-Trump mob, and his latest attempt to insulate the former president and Republican lawmakers from scrutiny of any ties to the violence. It came after he led the G.O.P. opposition to the creation of an independent bipartisan commission to investigate the riot, and then pulled five Republican congressmen from the select committee that Democrats created on their own, boycotting the proceedings."""], + ["""Asked about the report, Google responded in an email that its "advertising technologies help websites and apps fund their content, enable small businesses to grow, and protect users from exploitative privac y practices and bad ad experiences." A lawsuit by 38 U.S. states and territories accuses Google of abusing its market power in an effort to make its search engine as dominant inside cars, TVs and speakers as it is in phones. This was consolidated with the federal lawsuit for purposes of discovery. 
Texas, backed by other states, filed a separate lawsuit against Google, accusing it of breaking antitrust law in how it runs its online advertising business."""]]).launch(debug=True) \ No newline at end of file diff --git a/build/scripts-3.9/summertime b/build/scripts-3.9/summertime new file mode 100755 index 0000000000000000000000000000000000000000..2bbe1b6a2b83f4f515c94f4c9109b0e3d47706e6 --- /dev/null +++ b/build/scripts-3.9/summertime @@ -0,0 +1,3 @@ +#!python + +print("welcome to Summer Time!") diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbab0876cdedc94df38fe37e182772c33b7bf8b8 --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,36 @@ +from dataset.dataset_loaders import ( + CnndmDataset, + MultinewsDataset, + SamsumDataset, + XsumDataset, + PubmedqaDataset, + MlsumDataset, + ScisummnetDataset, + SummscreenDataset, + QMsumDataset, + ArxivDataset, +) + + +SUPPORTED_SUMM_DATASETS = [ + CnndmDataset, + MultinewsDataset, + SamsumDataset, + XsumDataset, + PubmedqaDataset, + MlsumDataset, + ScisummnetDataset, + SummscreenDataset, + QMsumDataset, + ArxivDataset, +] + + +def list_all_datasets(): + all_datasets = [] + for ds in SUPPORTED_SUMM_DATASETS: + dataset_description = ds.generate_basic_description() + + all_datasets.append((ds.dataset_name, dataset_description)) + + return all_datasets diff --git a/dataset/dataset_loaders.py b/dataset/dataset_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f1e0637181447dcf76afdc0733009570ad58a9 --- /dev/null +++ b/dataset/dataset_loaders.py @@ -0,0 +1,501 @@ +from os import path +from tqdm import tqdm +from typing import List, Generator, Optional, Union + +from datasets import Dataset + +from dataset.st_dataset import SummInstance, SummDataset + + +# Set directory to load non_huggingface dataset scripts +FILE_DIRECTORY_PATH = path.dirname(path.realpath(__file__)) +BASE_NONHUGGINGFACE_DATASETS_PATH = path.join( + FILE_DIRECTORY_PATH, "non_huggingface_datasets_builders" +) + + +# Huggingface Datasets + + +class CnndmDataset(SummDataset): + """ + The CNN/DM dataset + """ + + dataset_name = "CNN/DailyMail" + + is_query_based = False + is_dialogue_based = False + is_multi_document = False + + huggingface_dataset = True + huggingface_page = "https://huggingface.co/datasets/cnn_dailymail" + + def __init__(self): + super().__init__( + dataset_args=( + "cnn_dailymail", + "3.0.0", + ) + ) + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + article: str = instance["article"] + highlights: str = instance["highlights"] + summ_instance = SummInstance(source=article, summary=highlights) + + yield summ_instance + + +class MultinewsDataset(SummDataset): + """ + The Multi News dataset + """ + + dataset_name = "Multinews" + + is_query_based = False + is_dialogue_based = False + is_multi_document = True + + huggingface_dataset = True + huggingface_page = "https://huggingface.co/datasets/multi_news" + + def __init__(self): + super().__init__(dataset_args=("multi_news",)) + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method 
processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + document: list = [ + doc for doc in instance["document"].split("|||||") if doc + ] # removes the empty string generated + # since each doc ends with the delimiting token '|||||' + # the final doc creates an empty string + summary: str = instance["summary"] + summ_instance = SummInstance(source=document, summary=summary) + + yield summ_instance + + +class SamsumDataset(SummDataset): + """ + The SAMsum Dataset + """ + + dataset_name = "Samsum" + + is_query_based = False + is_dialogue_based = True + is_multi_document = False + + huggingface_dataset = True + huggingface_page = "https://huggingface.co/datasets/samsum" + + def __init__(self): + super().__init__(dataset_args=("samsum",)) + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + dialogue: List = instance["dialogue"].split( + "\r\n" + ) # split each dialogue into a list of strings such as + # ["speaker1 : utter..", "speaker2 : utter..."] + summary: str = instance["summary"] + summ_instance = SummInstance(source=dialogue, summary=summary) + + yield summ_instance + + +class XsumDataset(SummDataset): + """ + The Xsum Dataset + """ + + dataset_name = "Xsum" + + huggingface_dataset = True + huggingface_page = "https://huggingface.co/datasets/xsum" + + is_query_based = False + is_dialogue_based = False + is_multi_document = False + + def __init__(self): + super().__init__(dataset_args=("xsum",)) + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + document: List = instance["document"] + summary: str = instance["summary"] + summ_instance = SummInstance(source=document, summary=summary) + + yield summ_instance + + +class PubmedqaDataset(SummDataset): + """ + The Pubmed QA dataset + """ + + dataset_name = "Pubmedqa" + + is_query_based = True + is_dialogue_based = False + is_multi_document = False + + huggingface_dataset = True + huggingface_page = "https://huggingface.co/datasets/pubmed_qa" + + def __init__(self, seed=None): + super().__init__( + dataset_args=( + "pubmed_qa", + "pqa_artificial", + ) + ) + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + context: str = " ".join(instance["context"]["contexts"]) + answer: str = instance["long_answer"] + query: str = instance["question"] + summ_instance = SummInstance(source=context, summary=answer, query=query) + + yield summ_instance + + +class
MlsumDataset(SummDataset): + """ + The MLsum Dataset - A multi-lingual dataset featuring 5 languages + Includes 1.5 million news articles and their corresponding summaries + + "de" - German + "es" - Spanish + "fr" - French + "ru" - Russian + "tu" - Turkish + """ + + dataset_name = "MlSum" + + is_query_based = False + is_dialogue_based = False + is_multi_document = False + + huggingface_dataset = True + huggingface_page = "https://huggingface.co/datasets/mlsum" + supported_languages = ["de", "es", "fr", "ru", "tu"] + + mlsum_instantiation_guide = """The languages supported for the Mlsum Dataset are: + de - German + es - Spanish + fr - French + ru - Russian + tu - Turkish + + Examples to instantiate the dataset: + 1. Dataset with only one language + dataset = MlsumDataset({language_token}) + dataset = MlsumDataset("es") + dataset = MlsumDataset("tu")... + + 2. Dataset with a multiple languages + dataset = MlsumDataset({list of language_token}) + dataset = MlsumDataset(["es","de"]) + dataset = MlsumDataset(["es","de", "tu"])... + + 3. Dataset with all supported languages (default) + dataset = MlsumDataset(all) + dataset = MlsumDataset() + """ + + def __init__(self, languages: Optional[Union[str, List[str]]] = "all"): + super().__init__(dataset_args=(languages,)) + + def _load_dataset_safe(self, languages: Optional[Union[str, List[str]]]): + """ + Overrides the parent class method + Method loads multiple datasets of different languages provided in :param languages: + It then concatenates these datasets into one combined dataset + :rtype: datasetDict containing the combined dataset + :param languages: Optional, either a string or list of strings specifying the languages + to load + """ + print(MlsumDataset.mlsum_instantiation_guide) + + # Choose languages to download articles + if languages == "all": + selected_languages = MlsumDataset.supported_languages + elif isinstance(languages, list): + for language in languages: + assert self.is_supported(language) + selected_languages = languages + else: + assert self.is_supported(languages) + selected_languages = [languages] + + # Concatenate selected languaeges into one dataset + language_datasets = [] + for language in selected_languages: + dataset = super()._load_dataset_safe( + "mlsum", + language, + ) + + language_datasets.append(dataset) + + mlsum_dataset = self._concatenate_dataset_dicts(language_datasets) + + return mlsum_dataset + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + article: List = instance["text"] + summary: str = instance["summary"] + summ_instance = SummInstance(source=article, summary=summary) + + yield summ_instance + + def is_supported(self, language: str): + """ + Checks whether the requested langues is supported + :param language: string containing the requested language + :rtype bool: + """ + if language not in MlsumDataset.supported_languages: + print(MlsumDataset.mlsum_instantiation_guide) + raise ValueError( + f"The language(s): '{language}' entered is not supported. See above message for usage info" + ) + else: + return True + + +# Non-huggingface datasets + + +class ScisummnetDataset(SummDataset): + """ + The SciSummNet dataset. 
As a dataset not included by huggingface, we need to do manually download, set basic + information for the dataset + """ + + dataset_name = "ScisummNet" + + version = "1.1.0" + description = ( + "A summary of scientific papers should ideally incorporate the impact of the papers on the " + "research community reflected by citations. To facilitate research in citation-aware scientific " + "paper summarization (Scisumm), the CL-Scisumm shared task has been organized since 2014 for " + "papers in the computational linguistics and NLP domain." + ) + + is_dialogue_based = False + is_multi_document = False + is_query_based = False + + huggingface_dataset = False + builder_script_path = path.join( + BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py" + ) + + def __init__(self, seed=None): + super().__init__() + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + docs: List = [ + instance["document_xml"], + instance["citing_sentences_annotated.json"], + ] + summary: str = instance["summary"] + summ_instance = SummInstance(source=docs, summary=summary) + + yield summ_instance + + +class SummscreenDataset(SummDataset): + """ + The SummScreen dataset. As a dataset not included by huggingface, we need to do manually download, set basic + information for the dataset + """ + + dataset_name = "Summscreen" + + version = "1.1.0" + is_dialogue_based = True + is_multi_document = False + is_query_based = False + + huggingface_dataset = False + builder_script_path = path.join( + BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py" + ) + + def __init__(self, seed=None): + super().__init__() + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + transcript: List = instance[ + "transcript" + ] # convert string into a list of string dialogues + recap: str = instance["recap"] + summ_instance = SummInstance(source=transcript, summary=recap) + + yield summ_instance + + +class QMsumDataset(SummDataset): + """ + QMSum Dataset + """ + + dataset_name = "QMsum" + description = """ + QMSum is a new human-annotated benchmark for query-based multi-domain meeting summarization task, + which consists of 1,808 query-summary pairs over 232 meetings in multiple domains. 
+ """ + + is_dialogue_based = True + is_multi_document = False + is_query_based = True + + huggingface_dataset = False + builder_script_path = path.join( + BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py" + ) + + def __init__(self): + super().__init__() + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + for query_set in ( + instance["general_query_list"] + instance["specific_query_list"] + ): + meeting: List = [ + utterance["speaker"] + " : " + utterance["content"] + for utterance in instance["meeting_transcripts"] + ] + query: str = query_set["query"] + summary: str = query_set["answer"] + summ_instance = SummInstance( + source=meeting, summary=summary, query=query + ) + + yield summ_instance + + +class ArxivDataset(SummDataset): + """ + The Arxiv Dataset + """ + + dataset_name = "Arxiv_longsummarization" + description = """ + A summarization dataset comprised of pairs of scientific papers. + The dataset provides a challenging testbed for abstractive summarization. + It contains papers and their abstracts. + """ + + is_dialogue_based = False + is_multi_document = False + is_query_based = False + + huggingface_dataset = False + builder_script_path = path.join( + BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py" + ) + + def __init__(self): + + print( + "*****************", + "***Attention***", + "This dataset is quite large (approx 5Gb and will need about 15 Gb for the extraction process", + "Cancel/interrupt the download if size and time constraints will not be met", + "*****************", + sep="\n", + ) + + super().__init__() + + def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]: + """ + Overrides the SummDataset '_process_data()' method + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + for instance in tqdm(data): + article: List = instance["article_text"] + abstract: str = " ".join(instance["abstract_text"]) + summ_instance = SummInstance(source=article, summary=abstract) + + yield summ_instance diff --git a/dataset/non_huggingface_datasets_builders/arxiv_longsummarization.py b/dataset/non_huggingface_datasets_builders/arxiv_longsummarization.py new file mode 100644 index 0000000000000000000000000000000000000000..d88cb47755e3f3cd81777e1b38c918aa2046afcf --- /dev/null +++ b/dataset/non_huggingface_datasets_builders/arxiv_longsummarization.py @@ -0,0 +1,104 @@ +import os +import json +import datasets + + +"""Arxiv dataset.""" + + +_CITATION = """ +@article{Cohan_2018, + title={A Discourse-Aware Attention Model for Abstractive Summarization of + Long Documents}, + url={http://dx.doi.org/10.18653/v1/n18-2097}, + DOI={10.18653/v1/n18-2097}, + journal={Proceedings of the 2018 Conference of the North American Chapter of + the Association for Computational Linguistics: Human Language + Technologies, Volume 2 (Short Papers)}, + publisher={Association for Computational Linguistics}, + author={Cohan, Arman and Dernoncourt, Franck and Kim, Doo Soon and Bui, Trung and Kim, Seokhwan and Chang, Walter and Goharian, Nazli}, + 
year={2018} +} +""" + +_DESCRIPTION = """ +A summarization dataset comprised of pairs of scientific papers. +The dataset provides a challenging testbed for abstractive summarization. +It contains papers and their abstracts. +""" + +_HOMEPAGE = " https://github.com/armancohan/long-summarization" + +_LICENSE = "Apache-2.0 License" + +_URL = "https://archive.org/download/armancohan-long-summarization-paper-code/arxiv-dataset.zip" + + +class SummertimeArxiv(datasets.GeneratorBasedBuilder): + """Arxiv long summarization dataset.""" + + VERSION = datasets.Version("1.0.0") + + BUILDER_CONFIGS = [ + datasets.BuilderConfig(), + ] + + def _info(self): + features = datasets.Features( + { + "article_id": datasets.Value("string"), + "article_text": [datasets.Value("string")], + "abstract_text": [datasets.Value("string")], + } + ) + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + supervised_keys=None, + homepage=_HOMEPAGE, + license=_LICENSE, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + my_urls = _URL + path = dl_manager.download_and_extract(my_urls) + path = os.path.join(path, "arxiv-dataset") + + trainpath = os.path.join(path, "train.txt") + valpath = os.path.join(path, "val.txt") + testpath = os.path.join(path, "test.txt") + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepath": trainpath, "split": "train"}, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepath": valpath, "split": "val"}, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepath": testpath, "split": "test"}, + ), + ] + + def _generate_examples(self, filepath, split): + """Yields examples.""" + + with open(filepath, "r") as f: + for line in f: + + instance = json.loads(line) + + entry = {} + entry["article_id"] = instance["article_id"] + entry["article_text"] = instance["article_text"] + entry["abstract_text"] = instance["abstract_text"] + + yield entry["article_id"], entry diff --git a/dataset/non_huggingface_datasets_builders/qmsum.py b/dataset/non_huggingface_datasets_builders/qmsum.py new file mode 100644 index 0000000000000000000000000000000000000000..7d030c69495fcf1ee1b1b8dca1a56b95c39ca299 --- /dev/null +++ b/dataset/non_huggingface_datasets_builders/qmsum.py @@ -0,0 +1,119 @@ +import os +import json +import datasets + + +"""QMsum dataset.""" + + +_CITATION = """ +@inproceedings{zhong2021qmsum, + title={{QMS}um: {A} {N}ew {B}enchmark for {Q}uery-based {M}ulti-domain {M}eeting {S}ummarization}, + author={Zhong, Ming and Yin, Da and Yu, Tao and Zaidi, Ahmad and Mutuma, Mutethia and Jha, Rahul and Hassan Awadallah, Ahmed and Celikyilmaz, Asli and Liu, Yang and Qiu, Xipeng and Radev, Dragomir}, + booktitle={North American Association for Computational Linguistics (NAACL)}, + year={2021} +} +""" + +_DESCRIPTION = """ +QMSum is a new human-annotated benchmark for query-based multi-domain meeting summarization task, \ +which consists of 1,808 query-summary pairs over 232 meetings in multiple domains. 
+""" + +_HOMEPAGE = "https://github.com/Yale-LILY/QMSum" + +_BASE_URL = "https://raw.githubusercontent.com/Yale-LILY/QMSum/main/data/ALL/jsonl" +_URLs = { + "train": _BASE_URL + "/train.jsonl", + "val": _BASE_URL + "/val.jsonl", + "test": _BASE_URL + "/test.jsonl", +} + + +class SummertimeQmsum(datasets.GeneratorBasedBuilder): + """QMsum dataset.""" + + VERSION = datasets.Version("1.0.0") + + BUILDER_CONFIGS = [ + datasets.BuilderConfig(), + ] + + def _info(self): + features = datasets.Features( + { + "entry_number": datasets.Value("string"), + "meeting_transcripts": [ + { + "speaker": datasets.Value("string"), + "content": datasets.Value("string"), + } + ], + "general_query_list": [ + { + "query": datasets.Value("string"), + "answer": datasets.Value("string"), + } + ], + "specific_query_list": [ + { + "query": datasets.Value("string"), + "answer": datasets.Value("string"), + "relevant_text_span": [[datasets.Value("string")]], + } + ], + } + ) + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + supervised_keys=None, + homepage=_HOMEPAGE, + license=None, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + my_urls = _URLs + downloaded_files = dl_manager.download_and_extract(my_urls) + + trainpath = downloaded_files["train"] + valpath = downloaded_files["val"] + testpath = downloaded_files["test"] + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepath": trainpath, "split": "train"}, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepath": valpath, "split": "val"}, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepath": testpath, "split": "test"}, + ), + ] + + def _generate_examples(self, filepath, split): + """Yields examples.""" + + extraction_path = os.path.join(filepath) + + with open(extraction_path) as f: + for i, line in enumerate(f): + + instance = json.loads(line) + + entry = {} + entry["entry_number"] = split + "_" + str(i) + entry["meeting_transcripts"] = instance["meeting_transcripts"] + entry["general_query_list"] = instance["general_query_list"] + entry["specific_query_list"] = instance["specific_query_list"] + + yield entry["entry_number"], entry diff --git a/dataset/non_huggingface_datasets_builders/scisummnet.py b/dataset/non_huggingface_datasets_builders/scisummnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6bcfb5bfc02e09be903d988ec45d0a0a06606e --- /dev/null +++ b/dataset/non_huggingface_datasets_builders/scisummnet.py @@ -0,0 +1,105 @@ +import os +import datasets + + +"""Scisummnet dataset.""" + + +_CITATION = """ +@InProceedings{yasunaga&al.19.scisumm, + title = {{ScisummNet}: A Large Annotated Corpus and Content-Impact Models for Scientific Paper Summarization with Citation Networks}, + author = {Michihiro Yasunaga and Jungo Kasai and Rui Zhang and Alexander Fabbri and Irene Li and Dan Friedman and Dragomir Radev}, + booktitle = {Proceedings of AAAI 2019}, + year = {2019} +} +@InProceedings{yasunaga&al.17, + title = {Graph-based Neural Multi-Document Summarization}, + author = {Yasunaga, Michihiro and Zhang, Rui and Meelu, Kshitijh and Pareek, Ayush and Srinivasan, Krishnan and Radev, Dragomir R.}, + booktitle = {Proceedings of CoNLL 2017}, + year = {2017} +} +""" + +_DESCRIPTION = """ +A summary of 
scientific papers should ideally incorporate the impact of the papers on the research community +refl ected by citations. To facilitate research in citation-aware scientific paper summarization (Scisumm), +the CL-Scisumm shared task has been organized since 2014 for papers in the computational linguistics and NLP domain. +""" + +_HOMEPAGE = "https://cs.stanford.edu/~myasu/projects/scisumm_net/" + +_LICENSE = "CC BY-SA 4.0" + +_URLs = "https://cs.stanford.edu/~myasu/projects/scisumm_net/scisummnet_release1.1__20190413.zip" + + +class SummertimeScisummnet(datasets.GeneratorBasedBuilder): + """Scisummnet dataset.""" + + VERSION = datasets.Version("1.1.0") + + BUILDER_CONFIGS = [ + datasets.BuilderConfig(), + ] + + def _info(self): + features = datasets.Features( + { + "entry_number": datasets.Value("string"), + "document_xml": datasets.Value("string"), + "citing_sentences_annotated.json": datasets.Value("string"), + "summary": datasets.Value("string"), + } + ) + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + supervised_keys=None, + homepage=_HOMEPAGE, + license=_LICENSE, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + my_urls = _URLs + path = dl_manager.download_and_extract(my_urls) + trainpath = os.path.join( + path, "scisummnet_release1.1__20190413", "top1000_complete" + ) + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + # These kwargs will be passed to _generate_examples + gen_kwargs={"extraction_path": trainpath, "split": "train"}, + ) + ] + + def _generate_examples(self, extraction_path, split): + """Yields examples.""" + + for folder in os.listdir(extraction_path): + + entry = {} + + entry["entry_number"] = folder + + doc_xml_path = os.path.join( + extraction_path, folder, "Documents_xml", folder + ".xml" + ) + with open(doc_xml_path, "r", encoding="utf-8") as f: + entry["document_xml"] = f.read() + + cite_annot_path = os.path.join( + extraction_path, folder, "citing_sentences_annotated.json" + ) + with open(cite_annot_path, "r", encoding="utf-8") as f: + entry["citing_sentences_annotated.json"] = f.read() + + summary_path = os.path.join( + extraction_path, folder, "summary", folder + ".gold.txt" + ) + with open(summary_path, "r", encoding="utf-8") as f: + entry["summary"] = f.read() + + yield entry["entry_number"], entry diff --git a/dataset/non_huggingface_datasets_builders/summscreen.py b/dataset/non_huggingface_datasets_builders/summscreen.py new file mode 100644 index 0000000000000000000000000000000000000000..871b2fbaf273847aa6165b5f232fee6d1f568027 --- /dev/null +++ b/dataset/non_huggingface_datasets_builders/summscreen.py @@ -0,0 +1,123 @@ +import os +import json +import datasets + + +"""Summscreen dataset.""" + + +_CITATION = """ +@article{DBLP:journals/corr/abs-2104-07091, + author = {Mingda Chen and + Zewei Chu and + Sam Wiseman and + Kevin Gimpel}, + title = {SummScreen: {A} Dataset for Abstractive Screenplay Summarization}, + journal = {CoRR}, + volume = {abs/2104.07091}, + year = {2021}, + url = {https://arxiv.org/abs/2104.07091}, + archivePrefix = {arXiv}, + eprint = {2104.07091}, + timestamp = {Mon, 19 Apr 2021 16:45:47 +0200}, + biburl = {https://dblp.org/rec/journals/corr/abs-2104-07091.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" + +_DESCRIPTION = """ +A summary of scientific papers should ideally incorporate the impact of the papers on the research community +reflected by citations. 
To facilitate research in citation -aware scientific paper summarization (Scisumm), +the CL-Scisumm shared task has been organized since 2014 for papers in the computational linguistics and NLP domain. +""" + +_HOMEPAGE = "https://github.com/mingdachen/SummScreen" + +_LICENSE = "MIT Licencse" + +_URLs = "https://drive.google.com/uc?id=1BvdIllGBo9d2-bzXQRzWuJXB04XPVmfF" + + +class SummertimeSummscreen(datasets.GeneratorBasedBuilder): + """Summscreen dataset.""" + + VERSION = datasets.Version("1.1.0") + + BUILDER_CONFIGS = [ + datasets.BuilderConfig(), + ] + + def _info(self): + features = datasets.Features( + { + "entry_number": datasets.Value("string"), + "transcript": datasets.features.Sequence(datasets.Value("string")), + "recap": datasets.Value("string"), + } + ) + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + supervised_keys=None, + homepage=_HOMEPAGE, + license=_LICENSE, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + my_urls = _URLs + path = dl_manager.download_and_extract(my_urls) + path = os.path.join(path, "SummScreen") + + trainpath_fd = os.path.join("ForeverDreaming", "fd_train.json") + trainpath_tms = os.path.join("TVMegaSite", "tms_train.json") + trainpaths = [trainpath_fd, trainpath_tms] + + devpath_fd = os.path.join("ForeverDreaming", "fd_dev.json") + devpath_tms = os.path.join("TVMegaSite", "tms_dev.json") + devpaths = [devpath_fd, devpath_tms] + + testpath_fd = os.path.join("ForeverDreaming", "fd_test.json") + testpath_tms = os.path.join("TVMegaSite", "tms_test.json") + testpaths = [testpath_fd, testpath_tms] + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepaths": (path, trainpaths), "split": "train"}, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepaths": (path, devpaths), "split": "dev"}, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + # These kwargs will be passed to _generate_examples + gen_kwargs={"filepaths": (path, testpaths), "split": "test"}, + ), + ] + + def _generate_examples(self, filepaths, split): + """Yields examples.""" + + path, relative_filepaths = filepaths + for filepath in relative_filepaths: + + extraction_path = os.path.join(path, filepath) + + with open(extraction_path, "r") as f: + for line in f: + processed_line = line.replace("@@ ", "") + instance = json.loads(processed_line) + + entry = {} + entry["entry_number"] = instance["filename"] + entry["transcript"] = instance["Transcript"] + entry["recap"] = instance["Recap"][ + 0 + ] # Recap is a single string in list + + yield entry["entry_number"], entry diff --git a/dataset/st_dataset.py b/dataset/st_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f53c6a1dd945384fbed146fd8076d6eb4fdeb9e --- /dev/null +++ b/dataset/st_dataset.py @@ -0,0 +1,281 @@ +from abc import abstractmethod +from pprint import pformat +from time import sleep +from typing import List, Tuple, Optional, Union, Generator + +from datasets import ( + Dataset, + DatasetDict, + DatasetInfo, + concatenate_datasets, + load_dataset, +) + +# Defualt values for retrying dataset download +DEFAULT_NUMBER_OF_RETRIES_ALLOWED = 5 +DEFAULT_WAIT_SECONDS_BEFORE_RETRY = 5 + +# Default value for creating missing val/test splits +TEST_OR_VAL_SPLIT_RATIO = 0.1 + + +class SummInstance: + """ + Basic instance for summarization tasks + 
""" + + def __init__( + self, source: Union[List[str], str], summary: str, query: Optional[str] = None + ): + """ + Create a summarization instance + :rtype: object + :param source: either `List[str]` or `str`, depending on the dataset itself, string joining may needed to fit + into specific models. For example, for the same document, it could be simply `str` or `List[str]` for + a list of sentences in the same document + :param summary: a string summary that serves as ground truth + :param query: Optional, applies when a string query is present + """ + self.source = source + self.summary = summary + self.query = query + + def __repr__(self): + instance_dict = {"source": self.source, "summary": self.summary} + if self.query: + instance_dict["query"] = self.query + + return str(instance_dict) + + def __str__(self): + instance_dict = {"source": self.source, "summary": self.summary} + if self.query: + instance_dict["query"] = self.query + + return pformat(instance_dict, indent=1) + + +class SummDataset: + """ + Dataset class for summarization, which takes into account of the following tasks: + * Single document summarization + * Multi-document/Dialogue summarization + * Query-based summarization + """ + + def __init__( + self, dataset_args: Optional[Tuple[str]] = None, splitseed: Optional[int] = None + ): + """Create dataset information from the huggingface Dataset class + :rtype: object + :param dataset_args: a tuple containing arguments to passed on to the 'load_dataset_safe' method. + Only required for datasets loaded from the Huggingface library. + The arguments for each dataset are different and comprise of a string or multiple strings + :param splitseed: a number to instantiate the random generator used to generate val/test splits + for the datasets without them + """ + + # Load dataset from huggingface, use default huggingface arguments + if self.huggingface_dataset: + dataset = self._load_dataset_safe(*dataset_args) + # Load non-huggingface dataset, use custom dataset builder + else: + dataset = self._load_dataset_safe(path=self.builder_script_path) + + info_set = self._get_dataset_info(dataset) + + # Ensure any dataset with a val or dev or validation split is standardised to validation split + if "val" in dataset: + dataset["validation"] = dataset["val"] + dataset.remove("val") + elif "dev" in dataset: + dataset["validation"] = dataset["dev"] + dataset.remove("dev") + + # If no splits other other than training, generate them + assert ( + "train" in dataset or "validation" in dataset or "test" in dataset + ), "At least one of train/validation test needs to be not empty!" 
+ + if not ("validation" in dataset or "test" in dataset): + dataset = self._generate_missing_val_test_splits(dataset, splitseed) + + self.description = info_set.description + self.citation = info_set.citation + self.homepage = info_set.homepage + + # Extract the dataset entries from folders and load into dataset + self._train_set = self._process_data(dataset["train"]) + self._validation_set = self._process_data( + dataset["validation"] + ) # Some datasets have a validation split + self._test_set = self._process_data(dataset["test"]) + + @property + def train_set(self) -> Union[Generator[SummInstance, None, None], List]: + if self._train_set is not None: + return self._train_set + else: + print( + f"{self.d ataset_name} does not contain a train set, empty list returned" + ) + return list() + + @property + def validation_set(self) -> Union[Generator[SummInstance, None, None], List]: + if self._validation_set is not None: + return self._validation_set + else: + print( + f"{self.dataset_name} does not contain a validation set, empty list returned" + ) + return list() + + @property + def test_set(self) -> Union[Generator[SummInstance, None, None], List]: + if self._test_set is not None: + return self._test_set + else: + print( + f"{self.dataset_name} does not contain a test set, empty list returned" + ) + return list() + + def _load_dataset_safe(self, *args, **kwargs) -> Dataset: + """ + This method creates a wrapper around the huggingface 'load_dataset()' function for a more robust download function, + the original 'load_dataset()' function occassionally fails when it cannot reach a server especially after multiple requests. + This method tackles this problem by attempting the download multiple times with a wait time before each retry + + The wrapper method passes all arguments and keyword arguments to the 'load_dataset' function with no alteration. + :rtype: Dataset + :param args: non-keyword arguments to passed on to the 'load_dataset' function + :param kwargs: keyword arguments to passed on to the 'load_dataset' function + """ + + tries = DEFAULT_NUMBER_OF_RETRIES_ALLOWED + wait_time = DEFAULT_WAIT_SECONDS_BEFORE_RETRY + + for i in range(tries): + try: + dataset = load_dataset(*args, **kwargs) + except ConnectionError: + if i < tries - 1: # i is zero indexed + sleep(wait_time) + continue + else: + raise RuntimeError( + "Wait for a minute and attempt downloading the dataset again. \ + The server hosting the dataset occassionally times out." + ) + break + + return dataset + + def _get_dataset_info(self, data_dict: DatasetDict) -> DatasetInfo: + """ + Get the information set from the dataset + The information set contains: dataset name, description, version, citation and licence + :param data_dict: DatasetDict + :rtype: DatasetInfo + """ + return data_dict["train"].info + + @abstractmethod + def _process_data(self, dataset: Dataset) -> Generator[SummInstance, None, None]: + """ + Abstract class method to process the data contained within each dataset. 
+ Each dataset class processes its own information differently due to the diversity in domains. + This method processes the data contained in the dataset + and puts each data instance into a SummInstance object; + the SummInstance has the following properties [source, summary, query[optional]] + :param dataset: a train/validation/test dataset + :rtype: a generator yielding SummInstance objects + """ + return + + def _generate_missing_val_test_splits( + self, dataset_dict: DatasetDict, seed: int + ) -> DatasetDict: + """ + Create the train, val and test splits from a dataset. + The generated sets are 'train: ~.80', 'validation: ~.10', and 'test: ~.10' in size; + the splits are randomized for each object unless a seed is provided for the random generator + + :param dataset_dict: Arrow DatasetDict, usually containing only the train set + :param seed: seed for the random generator to shuffle the dataset + :rtype: Arrow DatasetDict containing the three splits + """ + + # Return dataset if no train set available for splitting + if "train" not in dataset_dict: + if "validation" not in dataset_dict: + dataset_dict["validation"] = None + if "test" not in dataset_dict: + dataset_dict["test"] = None + + return dataset_dict + + # Create a 'test' split from 'train' if no 'test' set is available + if "test" not in dataset_dict: + dataset_traintest_split = dataset_dict["train"].train_test_split( + test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed + ) + dataset_dict["train"] = dataset_traintest_split["train"] + dataset_dict["test"] = dataset_traintest_split["test"] + + # Create a 'validation' split from the remaining 'train' set if no 'validation' set is available + if "validation" not in dataset_dict: + dataset_trainval_split = dataset_dict["train"].train_test_split( + test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed + ) + dataset_dict["train"] = dataset_trainval_split["train"] + dataset_dict["validation"] = dataset_trainval_split["test"] + + return dataset_dict + + def _concatenate_dataset_dicts( + self, dataset_dicts: List[DatasetDict] + ) -> DatasetDict: + """ + Concatenate multiple dataset dicts with similar splits and columns into one + :param dataset_dicts: A list of DatasetDicts + :rtype: DatasetDict containing the combined data + """ + + # Ensure all dataset dicts have the same splits + setsofsplits = set(tuple(dataset_dict.keys()) for dataset_dict in dataset_dicts) + if len(setsofsplits) > 1: + raise ValueError("Splits must match for all datasets") + + # Concatenate all datasets into one according to the splits + temp_dict = {} + for split in setsofsplits.pop(): + split_set = [dataset_dict[split] for dataset_dict in dataset_dicts] + temp_dict[split] = concatenate_datasets(split_set) + + return DatasetDict(temp_dict) + + @classmethod + def generate_basic_description(cls) -> str: + """ + Automatically generate the basic description string based on the attributes + :rtype: string containing the description + :param cls: class object + """ + + basic_description = ( + f": {cls.dataset_name} is a " + f"{'query-based ' if cls.is_query_based else ''}" + f"{'dialogue ' if cls.is_dialogue_based else ''}" + f"{'multi-document' if cls.is_multi_document else 'single-document'} " + f"summarization dataset." + ) + + return basic_description + + def show_description(self): + """ + Print the description of the dataset.
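        Illustrative usage (the class name is hypothetical): SomeSummDataset().show_description()
        prints the dataset name on the first line, followed by the description taken from the
        Huggingface DatasetInfo.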
+ """ + print(self.dataset_name, ":\n", self.description) diff --git a/dependencies.txt b/dependencies.txt new file mode 100644 index 0000000000000000000000000000000000000000..920980068e8eba046ccdac72d445120b983b9fd4 --- /dev/null +++ b/dependencies.txt @@ -0,0 +1,11 @@ +Migrate information to documentation/pypi for first release. + +Dependencies: +- lexrank +- sentencepiece +- torch +- transformers + +# datasets +- datasets +- py7zr \ No newline at end of file diff --git a/dist/SummerTime-0.1-py3-none-any.whl b/dist/SummerTime-0.1-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..a7e651d45eed37ce88709b7a1dec1d6de5afc5d0 Binary files /dev/null and b/dist/SummerTime-0.1-py3-none-any.whl differ diff --git a/download.py b/download.py new file mode 100644 index 0000000000000000000000000000000000000000..3f59569e354853f0961315d42da1ab3226a96884 --- /dev/null +++ b/download.py @@ -0,0 +1,3 @@ +import nltk + +nltk.download("stopwords") diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5a9bf0790375852ca51750e45a4cbc91968275 --- /dev/null +++ b/evaluation/__init__.py @@ -0,0 +1,14 @@ +import site +import os + +# needed so that rouge works +package_path = site.getsitepackages()[0] +os.environ["ROUGE_HOME"] = package_path + "/summ_eval/RO UGE-1.5.5/" + +from .rouge_metric import Rouge +from .bertscore_metric import BertScore +from .rougewe_metric import RougeWe +from .bleu_metric import Bleu +from .meteor_metric import Meteor + +SUPPORTED_EVALUATION_METRICS = [BertScore, Bleu, Rouge, RougeWe, Meteor] diff --git a/evaluation/base_metric.py b/evaluation/base_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..fc6349011a2b7971ba7330e0d28579d9fe5a94fb --- /dev/null +++ b/evaluation/base_metric.py @@ -0,0 +1,27 @@ +from typing import List, Tuple, Dict + + +class SummMetric: + metric_name: str = None + range: Tuple[float, float] = None + higher_is_better: bool = None + requires_heavy_compute: bool = None + + def evaluate( + self, + # TODO zhangir: integrate with dataset api + inputs: List[str], + targets: List[str], + keys: List[str], + ) -> Dict[str, float]: + """ + All metrics should have this function. + :input: A list of summaries. + :target: A list of target summaries corresponding to each entry of input. + :keys: Which metrics to return, + e.g, ['rouge_1_f_score', 'rouge_2_f_score'] + :return: A dictionary with keys metrics and values scores. + """ + raise NotImplementedError( + "the base class for metrics shouldn't be instantiated!" 
+ ) diff --git a/evaluation/bertscore_metric.py b/evaluation/bertscore_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef6dedafd9837a1eedeef05761075ffba6e7a7f --- /dev/null +++ b/evaluation/bertscore_metric.py @@ -0,0 +1,20 @@ +from summ_eval.bert_score_metric import BertScoreMetric +from evaluation.summeval_metric import SummEvalMetric +from typing import List, Dict + + +class BertScore(SummEvalMetric): + metric_name = "bert score" + range = (0, 1) + higher_is_better = True + requires_heavy_compute = True + + def __init__(self): + se_metric = BertScoreMetric() + super(BertScore, self).__init__(se_metric) + + def evaluate( + self, inputs: List[str], targets: List[str], keys: List[str] = ["bert_score_f1"] + ) -> Dict[str, float]: + # TODO zhangir: update when datasets api is merged + return super(BertScore, self).evaluate(inputs, targets, keys) diff --git a/evaluation/bleu_metric.py b/evaluation/bleu_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6c0b5730d647aacca797ff5303c74b8e7517fb --- /dev/null +++ b/evaluation/bleu_metric.py @@ -0,0 +1,20 @@ +from summ_eval.bleu_metric import BleuMetric +from evaluation.summeval_metric import SummEvalMetric +from typing import List, Dict + + +class Bleu(SummEvalMetric): + metric_name = "bleu" + range = (0, 100) + higher_is_better = True + requires_heavy_compute = False + + def __init__(self): + se_metric = BleuMetric() + super(Bleu, self).__init__(se_metric) + + def evaluate( + self, inputs: List[str], targets: List[str], keys: List[str] = ["bleu"] + ) -> Dict[str, float]: + # TODO zhangir: potentially update when dataset api is merged. + return super(Bleu, self).evaluate(inputs, targets, keys) diff --git a/evaluation/meteor_metric.py b/evaluation/meteor_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c6c0bfc340b461a9660d6a2da63a35d3e1177a --- /dev/null +++ b/evaluation/meteor_metric.py @@ -0,0 +1,31 @@ +from .base_metric import SummMetric +from typing import List, Dict +from nltk.translate import meteor_score as nltk_meteor +import nltk +import statistics + + +class Meteor(SummMetric): + metric_name = "meteor" + range = (0, 1) + higher_is_better = True + requires_heavy_compute = False + + def __init__(self): + nltk.download("wordnet") + + def evaluate( + self, inputs: List[str], targets: List[str], keys=["meteor"] + ) -> Dict[str, float]: + + for key in keys: + if key != "meteor": + raise KeyError(key, "is not a valid key") + + meteor_scores = [ + nltk_meteor.meteor_score([input], target) + for input, target in zip(inputs, targets) + ] + meteor_score = statistics.mean(meteor_scores) + + return {key: meteor_score for key in keys} diff --git a/evaluation/rouge_metric.py b/evaluation/rouge_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..65c52db2fdbb344066393d9a3c8f17984d63ddba --- /dev/null +++ b/evaluation/rouge_metric.py @@ -0,0 +1,23 @@ +from summ_eval.rouge_metric import RougeMetric +from evaluation.summeval_metric import SummEvalMetric +from typing import List, Dict + + +class Rouge(SummEvalMetric): + metric_name = "rouge" + range = (0, 1) + higher_is_better = True + requires_heavy_compute = False + + def __init__(self): + se_metric = RougeMetric() + super(Rouge, self).__init__(se_metric) + + def evaluate( + self, + inputs: List[str], + targets: List[str], + keys: List[str] = ["rouge_1_f_score", "rouge_2_f_score", "rouge_l_f_score"], + ) -> Dict[str, float]: + score_dict = self.se_metric.evaluate_batch(inputs, targets) + 
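        # summ_eval's RougeMetric nests its scores under a "rouge" key in the returned
        # dict, hence the extra level of indexing below; this differs from the flat
        # lookup used by the generic SummEvalMetric.evaluate.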
return {key: score_dict["rouge"][key] for key in keys} diff --git a/evaluation/rougewe_metric.py b/evaluation/rougewe_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..b27aa0ce2266903a3aa898e6e1e4ea095ecbf1cf --- /dev/null +++ b/evaluation/rougewe_metric.py @@ -0,0 +1,24 @@ +from evaluation.summeval_metric import SummEvalMetric +from typing import List, Dict + +import nltk + + +class RougeWe(SummEvalMetric): + metric_name = "rougeWE" + range = (0, 1) + higher_is_better = True + requires_heavy_compute = True + + def __init__(self): + from summ_eval.rouge_we_metric import RougeWeMetric + + nltk.download("stopwords") + se_metric = RougeWeMetric() + super(RougeWe, self).__init__(se_metric) + + def evaluate( + self, inputs: List[str], targets: List[str], keys: List[str] = ["rouge_we_3_f"] + ) -> Dict[str, float]: + # TODO zhangir: update when dataset api is merged. + return super(RougeWe, self).evaluate(inputs, targets, keys) diff --git a/evaluation/summeval_metric.py b/evaluation/summeval_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..5b59ecbd5aa29bbf5a93ff0a95ab7bc31df8ae0c --- /dev/null +++ b/evaluation/summeval_metric.py @@ -0,0 +1,18 @@ +from .base_metric import SummMetric +from summ_eval.metric import Metric as SEMetric +from typing import List, Dict + + +class SummEvalMetric(SummMetric): + """ + Generic class for a summarization metric whose backend is SummEval. + """ + + def __init__(self, se_metric: SEMetric): + self.se_metric = se_metric + + def evaluate( + self, inputs: List[str], targets: List[str], keys: List[str] + ) -> Dict[str, float]: + score_dict = self.se_metric.evaluate_batch(inputs, targets) + return {key: score_dict[key] if key in score_dict else None for key in keys} diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..330a910a951c46a985342cb40b9d148d36fd65bf --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,34 @@ +from .single_doc import ( + BartModel, + LexRankModel, + LongformerModel, + PegasusModel, + TextRankModel, +) +from .multi_doc import MultiDocJointModel, MultiDocSeparateModel +from .dialogue import HMNetModel +from .query_based import TFIDFSummModel, BM25SummModel +from .defaults import summarizer + +SUPPORTED_SUMM_MODELS = [ + BartModel, + LexRankModel, + LongformerModel, + PegasusModel, + TextRankModel, + MultiDocJointModel, + MultiDocSeparateModel, + HMNetModel, + TFIDFSummModel, + BM25SummModel, +] + + +def list_all_models(): + all_model_tuples = [] + for model_class in SUPPORTED_SUMM_MODELS: + model_description = model_class.generate_basic_description() + + all_model_tuples.append((model_class, model_description)) + + return all_model_tuples diff --git a/mode l/base_model.py b/model/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5a1bcf065295f3b8058f56e313bd2d1dc4188b --- /dev/null +++ b/model/base_model.py @@ -0,0 +1,81 @@ +from typing import List, Union + + +class SummModel: + """ + Base model class for SummerTime + """ + + # static variables + model_name = "None" + is_extractive = False + is_neural = False + is_query_based = False + is_dialogue_based = False + is_multi_document = False + + def __init__( + self, + trained_domain: str = None, + max_input_length: int = None, + max_output_length: int = None, + ): + self.trained_domain = trained_domain + self.max_input_length = max_input_length + self.max_output_length = max_output_length + + def summarize( + self, corpus: Union[List[str], 
List[List[str]]], queries: List[str] = None + ) -> List[str]: + """ + All summarization models should have this function + + :param corpus: each string in the list is a source document to be summarized; if the model is multi-document or + dialogue summarization model, then each instance contains a list of documents/utterances + :param queries: a list of queries if this is a query-based model + :return: a list of generated summaries + """ + raise NotImplementedError( + "The base class for models shouldn't be instantiated!" + ) + + @classmethod + def assert_summ_input_type( + cls, corpus: Union[List[str], List[List[str]]], queries: Union[List[str], None] + ): + """ + Verifies that type of input corpus or queries for summarization align with the model type. + """ + raise NotImplementedError( + "The base class for models shouldn't be instantiated!" + ) + + @classmethod + def show_capability(cls) -> None: + """ + Use concise language to show the strength and weakness for each model. Try not to use NLP terminologies + """ + raise NotImplementedError( + "The base class for models shouldn't be instantiated!" + ) + + @classmethod + def generate_basic_description(cls) -> str: + """ + Automatically generate the basic description string based on the attributes + """ + extractive_abstractive = "extractive" if cls.is_extractive else "abstractive" + neural = "neural" if cls.is_neural else "non-neural" + + basic_description = ( + f"{cls.model_name} is a" + f"{'query-based' if cls.is_query_based else ''} " + f"{extractive_abstractive}, {neural} model for summarization." + ) + if cls.is_multi_document or cls.is_dialogue_based: + basic_description += ( + f"It can handle {'multi-document' if cls.is_multi_document else ''} " + f"{'dialogue' if cls.is_dialogue_based else ''} textual data." 
+ ) + + return basic_description diff --git a/model/defaults.py b/model/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..b9acbf3ca368d343c760a4bf48a475d87fcf7ace --- /dev/null +++ b/model/defaults.py @@ -0,0 +1,10 @@ +from .single_doc import PegasusModel + + +class summarizer(PegasusModel): + def __init__(self, device="cpu"): + super(summarizer, self).__init__(device) + + def show_capability(self): + print("Pegasus is the default singe-document summarization model.") + super(summarizer, self).show_capability() diff --git a/model/dialogue/__init__.py b/model/dialogue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3cfbc34ec9abdf44eb4c8732fbf89668296637c --- /dev/null +++ b/model/dialogue/__init__.py @@ -0,0 +1 @@ +from .hmnet_model import HMNetModel diff --git a/model/dialogue/hmnet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json b/model/dialogue/hmnet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json new file mode 100644 index 00000000000 00000000000000000000000000000..b2f3e7348272a9d52d89db5781e66b600bbffaab --- /dev/null +++ b/model/dialogue/hmnet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json @@ -0,0 +1 @@ +[{"source": {"dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/"}, "task": "meeting", "name": "ami"}] \ No newline at end of file diff --git a/model/dialogue/hmnet/ExampleRawData/meeting_summarization/role_dict_ext.json b/model/dialogue/hmnet/ExampleRawData/meeting_summarization/role_dict_ext.json new file mode 100644 index 0000000000000000000000000000000000000000..9e26dfeeb6e641a33dae4961196235bdb965b21b --- /dev/null +++ b/model/dialogue/hmnet/ExampleRawData/meeting_summarization/role_dict_ext.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/model/dialogue/hmnet/config/dialogue.conf b/model/dialogue/hmnet/config/dialogue.conf new file mode 100644 index 0000000000000000000000000000000000000000..5a38368e9ce4402157e40ed5f92e5a6e418c6d4c --- /dev/null +++ b/model/dialogue/hmnet/config/dialogue.conf @@ -0,0 +1,98 @@ +################## +# Trainer settings +################## + +MODEL MeetingNet_Transformer +TASK HMNet +CRITERION MLECriterion + +SEED 1033 + +MAX_NUM_EPOCHS 20 +EVAL_PER_UPDATE_NUM 10 +UPDATES_PER_EPOCH 20 + +# The actuall learning rate will be multiplied with the number of GPUs +OPTIMIZER RAdam +START_LEARNING_RATE 1e-3 +LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler +WARMUP_STEPS 16000 +WARMUP_INIT_LR 1e-4 +WARMUP_END_LR 1e-3 + +# The actuall start learning rate equals START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP +# Model will be updated after every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples +GRADIENT_ACCUMULATE_STEP 5 + +GRAD_CLIPPING 2 + +################## +# Task settings +################## + +# This is the relative path to the directory where this conf file locates +USE_REL_DATA_PATH +TRAIN_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json +DEV_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json +TEST_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json +ROLE_DICT_FILE ../ExampleRawData/meeting_summarization/role_dict_ext.json + +MINI_BATCH 1 +MAX_PADDING_RATIO 1 +BATCH_READ_AHEAD 10 +DOC_SHUFFLE_BUF_SIZE 10 +SAMPLE_SHUFFLE_BUFFER_SIZE 10 +BATCH_SHUFFLE_BUFFER_SIZE 10 + +MAX_TRANSCRIPT_WORD 8300 +#MAX_SENT_LEN 30 +MAX_SENT_LEN 12 +# MAX_SENT_NUM 300 +MAX_SENT_NUM 60 + +################## +# Model settings +################## + +DROPOUT 0.1 +VOCAB_DIM 512 +ROLE_SIZE 
32 +ROLE_DIM 16 +POS_DIM 16 +ENT_DIM 16 + +USE_ROLE +USE_POSENT + +USE_BOS_TOKEN +USE_EOS_TOKEN + +TRANSFORMER_EMBED_DROPOUT 0.1 +TRANSFORMER_RESIDUAL_DROPOUT 0.1 +TRANSFORMER_ATTENTION_DROPOUT 0.1 +TRANSFORMER_LAYER 6 +TRANSFORMER_HEAD 8 +TRANSFORMER_POS_DISCOUNT 80 + +PRE_TOKENIZER TransfoXLTokenizer +PRE_TOKENIZER_PATH ../../../third_party/HMNet/ExampleInitModel/transfo-xl-wt103 +PYLEARN_MODEL ../../../third_party/HMNet/ExampleInitModel/AMI-finetuned +# e.g. PYLEARN_MODEL conf_hmnet_AMI_conf~/run_1/11600 + +################## +# Tokenizer settings +################## + +EXTRA_IDS 1000 + +################## +# Decoding settings +################## + +BEAM_WIDTH 6 +EVAL_TOKENIZED +EVAL_LOWERCASE +# MAX_GEN_LENGTH 300 +MAX_GEN_LENGTH 60 +MIN_GEN_LENGTH 10 +NO_REPEAT_NGRAM_SIZE 3 \ No newline at end of file diff --git a/model/dialogue/hmnet_model.py b/model/dialogue/hmnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..54385d7cd14c723ee99aa7282ee0d6c30802f2eb --- /dev/null +++ b/model/dialogue/hmnet_model.py @@ -0,0 +1,483 @@ +from model.base_model import SummModel +import argparse +import os +import torch +import gzip +import json +from model.third_party.HMNet.Models.Trainers.HMNetTrainer import HMNetTrainer +from model.third_party.HMNet.Utils.Arguments import Arguments + +import spacy + +nlp = spacy.load("en_core_web_sm", disable=["parser"]) +# tagger = nlp.get_pipe('tagger') +# ner = nlp.get_pipe('ner') +# POS = {w: i for i, w in enumerate([''] + list(tagger.labels))} +# ENT = {w: i for i, w in enumerate([''] + list(ner.move_names))} +# These two dicts are adapted from SpaCy 2.3.1, since HMNet's embedding for POS and ENT is fixed +POS = { + "": 0, + "$": 1, + "''": 2, + ",": 3, + "-LRB-": 4, + "-RRB-": 5, + ".": 6, + ":": 7, + "ADD": 8, + "AFX": 9, + "CC": 10, + "CD": 11, + "DT": 12, + "EX": 13, + "FW": 14, + "HYPH": 15, + "IN": 16, + "JJ": 17, + "JJR": 18, + "JJS": 19, + "LS": 20, + "MD": 21, + "NFP": 22, + "NN": 23, + "NNP": 24, + "NNPS": 25, + "NNS": 26, + "PDT": 27, + "POS": 28, + "PRP": 29, + "PRP$": 30, + "RB": 31, + "RBR": 32, + "RBS": 33, + "RP": 34, + "SYM": 35, + "TO": 36, + "UH": 37, + "VB": 38, + "VBD": 39, + "VBG": 40, + "VBN": 41, + "VBP": 42, + "VBZ": 43, + "WDT": 44, + "WP": 45, + "WP$": 46, + "WRB": 47, + "XX": 48, + "_SP": 49, + "``": 50, +} +ENT = { + "": 0, + "B-ORG": 1, + "B-DATE": 2, + "B-PERSON": 3, + "B-GPE": 4, + "B-MONEY": 5, + "B-CARDINAL": 6, + "B-NORP": 7, + "B-PERCENT": 8, + "B-WORK_OF_ART": 9, + "B-LOC": 10, + "B-TIME": 11, + "B-QUANTITY": 12, + "B-FAC": 13, + "B-EVENT": 14, + "B-ORDINAL": 15, + "B-PRODUCT": 16, + "B-LAW": 17, + "B-LANGUAGE": 18, + "I-ORG": 19, + "I-DATE": 20, + "I-PERSON": 21, + "I-GPE": 22, + "I-MONEY": 23, + "I-CARDINAL": 24, + "I-NORP": 25, + "I-PERCENT": 26, + "I-WORK_OF_ART": 27, + "I-LOC": 28, + "I-TIME": 29, + "I-QUANTITY": 30, + "I-FAC": 31, + "I-EVENT": 32, + "I-ORDINAL": 33, + "I-PRODUCT": 34, + "I-LAW": 35, + "I-LANGUAGE": 36, + "L-ORG": 37, + "L-DATE": 38, + "L-PERSON": 39, + "L-GPE": 40, + "L-MONEY": 41, + "L-CARDINAL": 42, + "L-NORP": 43, + "L-PERCENT": 44, + "L-WORK_OF_ART": 45, + "L-LOC": 46, + "L-TIME": 47, + "L-QUANTITY": 48, + "L-FAC": 49, + "L-EVENT": 50, + "L-ORDINAL": 51, + "L-PRODUCT": 52, + "L-LAW": 53, + "L-LANGUAGE": 54, + "U-ORG": 55, + "U-DATE": 56, + "U-PERSON": 57, + "U-GPE": 58, + "U-MONEY": 59, + "U-CARDINAL": 60, + "U-NORP": 61, + "U-PERCENT": 62, + "U-WORK_OF_ART": 63, + "U-LOC": 64, + "U-TIME": 65, + "U-QUANTITY": 66, + "U-FAC": 67, + "U-EVENT": 68, + "U-ORDINAL": 69, + 
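    # The prefixes in this table follow spaCy's BILUO tagging scheme
    # (B = begin, I = inside, L = last, U = unit-length entity, O = outside),
    # combined with the entity label; HMNet's fixed entity embedding indexes into it.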
"U-PRODUCT": 70, + "U-LAW": 71, + "U-LANGUAGE": 72, + "O": 73, +} + + +class HMNetModel(SummModel): + # static variables + model_name = "HMNET" + is_extractive = False + is_neural = True + is_dialogue_based = True + + def __init__( + self, + min_gen_length: int = 10, + max_gen_length: int = 300, + beam_width: int = 6, + **kwargs, + ): + """ + Create a summarization model with HMNet backbone. In the default setting, the inference speed will be + 10s/sample (on one GPU), however, if one can tune these three parameters properly, e.g. min_gen_length=10, + max_gen_length=100, and beam_width=2, the inference speed will increase to 2s/sample (on one GPU). + + Args: + min_gen_length (int): minimum generation length of the decoder + max_gen_length (int): maximum generation length of the decoder + beam_width (int): width of the beam when doing beam search in the decoding process + kwargs: the other valid parameters. The valid parameters can be found in + model/dialogue/hmnet/config/dialogue.conf . You can use either lower case or upper case for parameter + name. The valid parameter name is one of the following args, however, we do not encourage you to modify + them, since some unexpected, untested errors might be triggered: + ['MODEL', 'TASK', 'CRITERION', 'SEED', 'MAX_NUM_EPOCHS', 'EVAL_PER_UPDATE_NUM' + , 'UPDATES_PER_EPOCH', 'OPTIMIZER', 'START_LEARNING_RATE', 'LR_SCHEDULER', 'WARMUP_STEPS', + 'WARMUP_INIT_LR', 'WARMUP_END_LR', 'GRADIENT_ACCUMULATE_STEP', 'GRAD_CLIPPING', 'USE_REL_DATA_PATH', + 'TRAIN_FILE ', 'DEV_FILE', 'TEST_FILE', 'ROLE_DICT_FILE', 'MINI_BATCH', 'MAX_PADDING_RATIO', + 'BATCH_READ_AHEAD', 'DOC_SHUFFLE_BUF_SIZE', 'SAMPLE_SHUFFLE_BUFFER_SIZE', 'BATCH_SHUFFLE_BUFFER_SIZE', + 'MAX_TRANSCRIPT_WORD', 'MAX_SENT_LEN', 'MAX_SENT_NUM', 'DROPOUT', 'VOCAB_DIM', 'ROLE_SIZE', 'ROLE_DIM', + 'POS_DIM', 'ENT_DIM', 'USE_ROLE', 'USE_POSENT', 'USE_BOS_TOKEN', 'USE_EOS_TOKEN', + 'TRANSFORMER_EMBED_DROPOUT', 'TRANSFORMER_RESIDUAL_DROPOUT', 'TRANSFORMER_ATTENTION_DROPOUT', + 'TRANSFORMER_LAYER', 'TRANSFORMER_HEAD', 'TRANSFORMER_POS_DISCOUNT', 'PRE_TOKENIZER', + 'PRE_TOKENIZER_PATH', 'PYLEARN_MODEL', 'EXTRA_IDS', 'BEAM_WIDTH', 'EVAL_TOKENIZED', 'EVAL_LOWERCASE', + 'MAX_GEN_LENGTH', 'MIN_GEN_LENGTH', 'NO_REPEAT_NGRAM_SIZE'] + + Return an instance of HMNet model for dialogue summarization. + """ + super(HMNetModel, self).__init__() + self.root_path = self._get_root() + + # we leave the most influential params with prompt and the others as hidden kwargs + kwargs["MIN_GEN_LENGTH"] = min_gen_length + kwargs["MAX_GEN_LENGTH"] = max_gen_length + kwargs["BEAM_WIDTH"] = beam_width + self.opt = self._parse_args(kwargs) + self.model = HMNetTrainer(self.opt) + + def _get_root(self): + root_path = os.getcwd() + while "model" not in os.listdir(root_path): + root_path = os.path.dirname(root_path) + root_path = os.path.join(root_path, "model/dialogue") + return root_path + + def _parse_args(self, kwargs): + parser = argparse.ArgumentParser( + description="HMNet: Pretrain or fine-tune models for HMNet model." + ) + parser.add_argument( + "--command", default="evaluate", help="Command: train/evaluate" + ) + parser.add_argument( + "--conf_file", + default=os.path.join(self.root_path, "hmnet/config/dialogue.conf"), + help="Path to the BigLearn conf file.", + ) + parser.add_argument( + "--PYLEARN_MODEL", help="Overrides this option from the conf file." 
+ ) + parser.add_argument( + "--master_port", help="Overrides this option default", default=None + ) + parser.add_argument("--cluster", help="local, philly or aml", default="local") + parser.add_argument( + "--dist_init_path", help="Distributed init path for AML", default="./tmp" + ) + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit float precision instead of 32-bit", + ) + parser.add_argument( + "--fp16_opt_level", + type=str, + default="O1", + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html", + ) + parser.add_argument("--no_cuda", action="store_true", help="Disable cuda.") + parser.add_argument( + "--config_overrides", + help="Override parameters on config, VAR=val;VAR=val;...", + ) + + cmdline_args = parser.parse_args() + command = cmdline_args.command + conf_file = cmdline_args.conf_file + conf_args = Arguments(conf_file) + opt = conf_args.readArguments() + + if cmdline_args.config_overrides: + for config_override in cmdline_args.config_overrides.split(";"): + config_override = config_override.strip() + if config_override: + var_val = config_override.split("=") + assert ( + len(var_val) == 2 + ), f"Config override '{var_val}' does not have the form 'VAR=val'" + conf_args.add_opt(opt, var_val[0], var_val[1], force_override=True) + + opt["cuda"] = torch.cuda.is_available() and not cmdline_args.no_cuda + opt["confFile"] = conf_file + i f "datadir" not in opt: + opt["datadir"] = os.path.dirname( + conf_file + ) # conf_file specifies where the data folder is + opt["basename"] = os.path.basename( + conf_file + ) # conf_file specifies where the name of save folder is + opt["command"] = command + + # combine cmdline_args into opt dictionary + for key, val in cmdline_args.__dict__.items(): + # if val is not None and key not in ['command', 'conf_file']: + if val is not None: + opt[key] = val + + # combine kwargs into opt dictionary (we allow lower case) + for key, val in kwargs.items(): + valid_keys = [x for x in opt.keys() if x.upper() == x] + if key.upper() not in valid_keys: + print("WARNING: {} is not a valid key in HMNet.".format(key)) + print("The valid keys are:", valid_keys) + continue + if val is not None: + opt[key.upper()] = val + + return opt + + def summarize(self, corpus, queries=None): + print(f"HMNet model: processing document of {corpus.__len__()} samples") + # transform the original dataset to "dialogue" input + # we only use test set path for evaluation + data_folder = os.path.join( + os.path.dirname(self.opt["datadir"]), + "ExampleRawData/meeting_summarization/AMI_proprec/test", + ) + + self._create_datafolder(data_folder) + self._preprocess(corpus, data_folder) + + # return self.model.eval() + results = self._evaluate() + + return results + + def _evaluate(self): + if self.opt["rank"] == 0: + self.model.log("-----------------------------------------------") + self.model.log("Evaluating model ... ") + + self.model.set_up_model() + + eval_dataset = "test" + batch_generator_eval = self.model.get_batch_generator(eval_dataset) + predictions = self._eval_batches( + self.model.module, batch_generator_eval, self.model.saveFolder, eval_dataset + ) + + return predictions + + def _eval_batches(self, module, dev_batches, save_folder, label=""): + max_sent_len = int(self.opt["MAX_GEN_LENGTH"]) + + print("Decoding current model ... 
\nSaving folder is {}".format(save_folder)) + print("Each sample will cost about 10 second.") + import time + + start_time = time.time() + predictions = [] # prediction of tokens from model + if not isinstance(module.tokenizer, list): + decoder_tokenizer = module.tokenizer + elif len(module.tokenizer) == 1: + decoder_tokenizer = module.tokenizer[0] + elif len(module.tokenizer) == 2: + decoder_tokenizer = module.tokenizer[1] + else: + assert False, "len(module.tokenizer) > 2" + + with torch.no_grad(): + for j, dev_batch in enumerate(dev_batches): + for b in dev_batch: + if torch.is_tensor(dev_batch[b]): + dev_batch[b] = dev_batch[b].to(self.opt["device"]) + + beam_search_res = module( + dev_batch, beam_search=True, max_sent_len=max_sent_len + ) + pred = [ + [t[0] for t in x] if len(x) > 0 else [[]] for x in beam_search_res + ] + predictions.extend( + [ + [ + self._convert_tokens_to_string(decoder_tokenizer, tt) + for tt in t + ] + for t in pred + ] + ) + + if ( + "DEBUG" in self.opt and j >= 10 + ) or j >= self.model.task.evaluator.eval_batches_num: + # in debug mode (decode first 10 batches) ortherwise decode first self.eval_batches_num bathes + break + + top1_predictions = [x[0] for x in predictions] + + print("Total time for inference:", time.time() - start_time) + return top1_predictions + + def _convert_tokens_to_string(self, tokenizer, tokens): + if "EVAL_TOKENIZED" in self.opt: + tokens = [t for t in tokens if t not in tokenizer.all_special_tokens] + if "EVAL_LOWERCASE" in self.opt: + tokens = [t.lower() for t in tokens] + if "EVAL_TOKENIZED" in self.opt: + return " ".join(tokens) + else: + return tokenizer.decode( + tokenizer.convert_tokens_to_ids(tokens), skip_special_tokens=True + ) + + def _preprocess(self, corpus, test_path): + samples = [] + for i, sample in enumerate(corpus): + new_sample = {"id": i, "meeting": [], "summary": []} + if isinstance(sample, str): + raise RuntimeError( + "Error: the input of HMNet should be dialogues, rather than documents." + ) + + # add all the turns one by one + for turn in sample: + turn = [x.strip() for x in turn.split(":")] + if len(turn) < 2: + continue + tokenized_turn = nlp(turn[1]) + # In case we can't find proper entity in move_names + ent_id = [] + pos_id = [] + for token in tokenized_turn: + ent = ( + token.ent_iob_ + "-" + token.ent_type_ + if token.ent_iob_ != "O" + else "O" + ) + ent_id.append(ENT[ent] if ent in ENT else ENT[""]) + + pos = token.tag_ + pos_id.append(POS[pos] if pos in POS else POS[""]) + + new_sample["meeting"].append( + { + "speaker": turn[0], + "role": "", + "utt": { + "word": [str(token) for token in tokenized_turn], + "pos_id": pos_id, + "ent_id": ent_id, + }, + } + ) + new_sample["summary"].append( + "This is a dummy summary. HMNet will filter out the sample w/o summary!" 
+ ) + samples.append(new_sample) + # save to the gzip + file_path = os.path.join(test_path, "split_{}.jsonl.gz".format(i)) + with gzip.open(file_path, "wt", encoding="utf-8") as file: + file.write(json.dumps(new_sample)) + + def _clean_datafolder(self, data_folder): + for name in os.listdir(data_folder): + name = os.path.join(data_folder, name) + if ".gz" in name: + os.remove(name) + + def _create_datafolder(self, data_folder): + if os.path.exists(data_folder): + self._clean_datafolder(data_folder) + else: + os.makedirs(data_folder) + with open( + os.path.join(os.path.dirname(data_folder), "test_ami.json"), + "w", + encoding="utf-8", + ) as file: + json.dump( + [ + { + "source": { + "dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/" + }, + "task": "meeting", + "name": "ami", + } + ], + file, + ) + + with open( + os.path.join( + os.path.dirname(os.path.dirname(data_folder)), "role_dict_ext.json" + ), + "w", + ) as file: + json.dump({}, file) + + @classmethod + def show_capability(cls) -> None: + basic_description = cls.generate_basic_description() + more_details = ( + "A HMNet model finetuned on CNN-DM dataset for sum marization.\n\n" + "Strengths:\n - High performance on dialogue summarization task.\n\n" + "Weaknesses:\n - Not suitable for datasets other than dialogues.\n\n" + "Initialization arguments:\n " + " - `corpus`: Unlabelled corpus of documents.\n" + ) + print(f"{basic_description} \n {'#' * 20} \n {more_details}") diff --git a/model/multi_doc/__init__.py b/model/multi_doc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8e13c695783e5c32095bf2990196301204b3a2 --- /dev/null +++ b/model/multi_doc/__init__.py @@ -0,0 +1,2 @@ +from .multi_doc_joint_model import MultiDocJointModel +from .multi_doc_separate_model import MultiDocSeparateModel diff --git a/model/multi_doc/base_multi_doc_model.py b/model/multi_doc/base_multi_doc_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd304350cc6fef91acb348bcd8dfc03a8f039e9 --- /dev/null +++ b/model/multi_doc/base_multi_doc_model.py @@ -0,0 +1,40 @@ +from model.base_model import SummModel + + +class MultiDocSummModel(SummModel): + + is_multi_document = True + + def __init__( + self, + trained_domain: str = None, + max_input_length: int = None, + max_output_length: int = None, + ): + super(MultiDocSummModel, self).__init__( + trained_domain=trained_domain, + max_input_length=max_input_length, + max_output_length=max_output_length, + ) + + @classmethod + def assert_summ_input_type(cls, corpus, query): + if not all( + [ + isinstance(ins, list) and all([isinstance(doc, str) for doc in ins]) + for ins in corpus + ] + ): + raise TypeError( + "Multi-document summarization models summarize instances of multiple documents (`List[List[str]]`)." + ) + + if query is not None: + if not isinstance(query, list): + raise TypeError( + "Query-based single-document summarization requires query of `List[str]`." + ) + if not all([isinstance(q, str) for q in query]): + raise TypeError( + "Query-based single-document summarization requires query of `List[str]`." 
+ ) diff --git a/model/multi_doc/multi_doc_joint_model.py b/model/multi_doc/multi_doc_joint_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f3568a43cfacdc7dd1e4a8111cabdfccf425be --- /dev/null +++ b/model/multi_doc/multi_doc_joint_model.py @@ -0,0 +1,51 @@ +from .base_multi_doc_model import MultiDocSummModel +from model.base_model import SummModel +from model.single_doc import TextRankModel +from typing import Union, List + + +class MultiDocJointModel(MultiDocSummModel): + + model_name = "Multi-document joint" + is_multi_document = True + + def __init__(self, model_backend: SummModel = TextRankModel, **kwargs): + super(MultiDocJointModel, self).__init__() + model = model_backend(**kwargs) + self.model = model + + def summarize( + self, + corpus: Union[List[str], List[List[str]]], + query: Union[List[str], List[List[str]]] = None, + ) -> List[str]: + self.assert_summ_input_type(corpus, None) + joint_corpus = [] + for instance in corpus: + joint_corpus.append(" ".join(instance)) + + summaries = self.model.summarize(joint_corpus) + + return summaries + + @classmethod + def generate_basic_description(cls) -> str: + basic_description = ( + "MultiDocJointModel performs multi-document summarization by" + " first concatenating all documents," + " and then performing single-document summarization on the concatenation." + ) + return basic_description + + @classmethod + def show_capability(cls): + basic_description = cls.generate_basic_description() + more_details = ( + "A multi-document summarization model." + " Allows for custom model backend selection at initialization." + " Concatenates each document in corpus and returns single-document summarization of joint corpus.\n" + "Strengths: \n - Allows for control of backend model.\n" + "Weaknesses: \n - Assumes all documents are equally weighted.\n" + " - May fail to extract information from certain documents.\n" + ) + print(f"{basic_description}\n{'#' * 20}\n{more_details}") diff --git a/model/multi_doc/multi_doc_separate_model.py b/model/multi_doc/multi_doc_separate_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5eab2288cf9b44580726360c9989b9c0214ab4c1 --- /dev/null +++ b/model/multi_doc/multi_doc_separate_model.py @@ -0,0 +1,49 @@ +from .base_multi_doc_model import MultiDocSummModel +from model.base_model import SummModel +from model.single_doc import TextRankModel +from typing import Union, List + + +class MultiDocSeparateModel(MultiDocSummModel): + + model_name = "Multi-document separate" + is_multi_document = True + + def __init__(self, model_backend: SummModel = TextRankModel, **kwargs): + super(MultiDocSeparateModel, self).__init__() + model = model_backend(**kwargs) + self.model = model + + def summarize( + self, + corpus: Union[List[str], List[List[str]]], + query: Union[List[str], List[List[str]]] = None, + ) -> List[str]: + self.assert_summ_input_type(corpus, None) + summaries = [] + for instance in corpus: + instance_summaries = self.model.summarize(instance) + summaries.append(" ".join(instance_summaries)) + + return summaries + + @classmethod + def generate_basic_description(cls) -> str: + basic_description = ( + "MultiDocSeparateModel performs multi-document summarization by" + " first performing single-document summarization on each document," + " and then concatenating the results." + ) + return basic_description + + @classmethod + def show_capability(cls): + basic_description = cls.generate_basic_description() + more_details = ( + "A multi-document summarization model." 
+ " Allows for custom model backend selection at initialization." + " Performs single-document summarization on each document in corpus and returns concatenated result.\n" + "Strengths: \n - Allows for control of backend model.\n" + "Weaknesses: \n - Assumes all documents are equally weighted.\n - May produce redundant information for similar documents.\n" + ) + print(f"{basic_description}\n{'#' * 20}\n{more_details}") diff --git a/model/query_based/__init__.py b/model/query_based/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64940297f17e93a966bf7efba25308682eec0cd4 --- /dev/null +++ b/model/query_based/__init__.py @@ -0,0 +1,2 @@ +from .bm25_model import BM25SummModel +from .tf_idf_model import TFIDFSummModel diff --git a/model/query_based/base_query_based_model.py b/model/query_based/base_query_based_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9b94b5a3c7f4cc0bb894c7e0863524330887d6e5 --- /dev/null +++ b/model/query_based/base_query_based_model.py @@ -0,0 +1,147 @@ +from model.base_model import SummModel +from model.single_doc import TextRankModel +from typing import List, Union + +from nltk import sent_tokenize, word_tokenize +from nltk.corpus import stopwords +from nltk.stem import PorterStemmer + + +class QueryBasedSummModel(SummModel): + + is_query_based = True + + def __init__( + self, + trained_domain: str = None, + max_input_length: int = None, + max_output_length: int = None, + model_backend: SummModel = TextRankModel, + retrieval_ratio: float = 0.5, + preprocess: bool = True, + **kwargs, + ): + super(QueryBasedSummModel, self).__init__( + trained_domain=trained_domain, + max_ input_length=max_input_length, + max_output_length=max_output_length, + ) + self.model = model_backend(**kwargs) + self.retrieval_ratio = retrieval_ratio + self.preprocess = preprocess + + def _retrieve(self, instance: List[str], query: List[str], n_best) -> List[str]: + raise NotImplementedError() + + def summarize( + self, + corpus: Union[List[str], List[List[str]]], + queries: List[str] = None, + ) -> List[str]: + self.assert_summ_input_type(corpus, queries) + + retrieval_output = [] # List[str] + for instance, query in zip(corpus, queries): + if isinstance(instance, str): + is_dialogue = False + instance = sent_tokenize(instance) + else: + is_dialogue = True + query = [query] + + # instance & query now are List[str] for sure + if self.preprocess: + preprocessor = Preprocessor() + instance = preprocessor.preprocess(instance) + query = preprocessor.preprocess(query) + + n_best = max(int(len(instance) * self.retrieval_ratio), 1) + top_n_sent = self._retrieve(instance, query, n_best) + + if not is_dialogue: + top_n_sent = " ".join(top_n_sent) # str + retrieval_output.append(top_n_sent) + + summaries = self.model.summarize( + retrieval_output + ) # List[str] or List[List[str]] + return summaries + + def generate_specific_description(self): + is_neural = self.model.is_neural & self.is_neural + is_extractive = self.model.is_extractive | self.is_extractive + model_name = "Pipeline with retriever: {}, summarizer: {}".format( + self.model_name, self.model.model_name + ) + + extractive_abstractive = "extractive" if is_extractive else "abstractive" + neural = "neural" if is_neural else "non-neural" + + basic_description = ( + f"{model_name} is a " + f"{'query-based' if self.is_query_based else ''} " + f"{extractive_abstractive}, {neural} model for summarization." 
+ ) + + return basic_description + + @classmethod + def assert_summ_input_type(cls, corpus, query): + if query is None: + raise TypeError( + "Query-based summarization models summarize instances of query-text pairs, however, query is missing." + ) + + if not isinstance(query, list): + raise TypeError( + "Query-based single-document summarization requires query of `List[str]`." + ) + if not all([isinstance(q, str) for q in query]): + raise TypeError( + "Query-based single-document summarization requires query of `List[str]`." + ) + + @classmethod + def generate_basic_description(cls) -> str: + basic_description = ( + "QueryBasedSummModel performs query-based summarization. Given a query-text pair," + "the model will first extract the most relevant sentences in articles or turns in " + "dialogues, then use the single document summarization model to generate the summary" + ) + return basic_description + + @classmethod + def show_capability(cls): + basic_description = cls.generate_basic_description() + more_details = ( + "A query-based summarization model." + " Allows for custom model backend selection at initialization." + " Retrieve relevant turns and then summarize the retrieved turns\n" + "Strengths: \n - Allows for control of backend model.\n" + "Weaknesses: \n - Heavily depends on the performance of both retriever and summarizer.\n" + ) + print(f"{basic_description}\n{'#' * 20}\n{more_details}") + + +class Preprocessor: + def __init__(self, remove_stopwords=True, lower_case=True, stem=False): + self.sw = stopwords.words("english") + self.stemmer = PorterStemmer() + self.remove_stopwords = remove_stopwords + self.lower_case = lower_case + self.stem = stem + + def preprocess(self, corpus: List[str]) -> List[str]: + if self.lower_case: + corpus = [sent.lower() for sent in corpus] + tokenized_corpus = [word_tokenize(sent) for sent in corpus] + if self.remove_stopwords: + tokenized_corpus = [ + [word for word in sent if word not in self.sw] + for sent in tokenized_corpus + ] + if self.stem: + tokenized_corpus = [ + [self.stemmer.stem(word) for word in sent] for sent in tokenized_corpus + ] + return [" ".join(sent) for sent in tokenized_corpus] diff --git a/model/query_based/bm25_model.py b/model/query_based/bm25_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d5fc06bbebfe0d75eecd0ee239f7e56f4fc2ef17 --- /dev/null +++ b/model/query_based/bm25_model.py @@ -0,0 +1,45 @@ +from .base_query_based_model import QueryBasedSummModel +from model.base_model import SummModel +from model.single_doc import TextRankModel +from typing import List + +from gensim.summarization.bm25 import BM25 +from nltk import word_tokenize + + +class BM25SummModel(QueryBasedSummModel): + + # static variables + model_name = "BM25" + is_extractive = True # only represents the retrieval part + is_neural = False # only represents the retrieval part + is_query_based = True + + def __init__( + self, + trained_domain: str = None, + max_input_length: int = None, + max_output_length: int = None, + model_backend: SummModel = TextRankModel, + retrieval_ratio: float = 0.5, + preprocess: bool = True, + **kwargs + ): + super(BM25SummModel, self).__init__( + trained_domain=trained_domain, + max_input_length=max_input_length, + max_output_length=max_output_length, + model_backend=model_backend, + retrieval_ratio=retrieval_ratio, + preprocess=preprocess, + **kwargs + ) + + def _retrieve(self, instance: List[str], query: List[str], n_best): + bm25 = BM25(word_tokenize(s) for s in instance) + scores = 
bm25.get_scores(query) + best_sent_ind = sorted( + range(len(scores)), key=lambda i: scores[i], reverse=True + )[:n_best] + top_n_sent = [instance[ind] for ind in sorted(best_sent_ind)] + return top_n_sent diff --git a/model/query_based/tf_idf_model.py b/model/query_based/tf_idf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cecd798f0882212f5509b1549a65e8f752151ac9 --- /dev/null +++ b/model/query_based/tf_idf_model.py @@ -0,0 +1,46 @@ +from .base_query_based_model import QueryBasedSummModel +from model.base_model import SummModel +from model.single_doc import TextRankModel +from typing import List + +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity + + +class TFIDFSummModel(QueryBasedSummModel): + + # static variables + model_name = "TF-IDF" + is_extractive = True + is_neural = False + is_query_based = True + + def __init__( + self, + trained_domain: str = None, + max_input_length: int = None, + max_output_length: int = None, + model_backend: SummModel = TextRankModel, + retrieval_ratio: float = 0.5, + preprocess: bool = True, + **kwargs + ): + super(TFIDFSummModel, self).__init__( + trained_domain=trained_domain, + max_input_length=max_input_length, + max_output_length=max_output_length, + model_backend=model_backend, + retrieval_ratio=retrieval_ratio, + preprocess=preprocess, + **kwargs + ) + self.vectorizer = TfidfVectorizer() + + def _retrieve(self, instance: List[str], query: List[str], n_best): + instance_vectors = self.vectorizer.fit_transform(instance) + query_vector = self.vectorizer.transform(query) + + similarities = cosine_similarity(query_vector, instance_vectors).squeeze() + top_n_index = similarities.argsort()[::-1][0:n_best] + top_n_sent = [instance[ind] for ind in top_n_index] # List[str] + return top_n_sent diff --git a/model/single_doc/__init__.py b/model/single_doc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8a6c077acb36505a136b1ad1cc1ccd23844e1e --- /dev/null +++ b/model/single_doc/__init__.py @@ -0,0 +1,5 @@ +from .bart_model import BartModel +from .pegasus_model import PegasusModel +from .lexrank_model import LexRankModel +from .longformer_model import LongformerModel +from .textrank_model import TextRankModel diff --git a/model/single_doc/bart_model.py b/model/single_doc/bart_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d7108c277d76995550c578850b36a7e57b57354e --- /dev/null +++ b/model/single_doc/bart_model.py @@ -0,0 +1,36 @@ +from transformers import BartForConditionalGeneration, BartTokenizer +from .base_single_doc_model import SingleDocSummModel + + +class BartModel(SingleDocSummModel): + + # static variables + model_name = "BART" + is_extractive = False + is_neural = False + + def __init__(self, device="cpu"): + super(BartModel, self).__init__() + + self.device = device + model_name = "facebook/bart-large-cnn" + self.tokenizer = BartTokenizer.from_pretrained(model_name) + self.model = BartForConditionalGeneration.from_pretrained(model_name) + + def summarize(self, corpus, queries=None): + self.assert_summ_input_type(corpus, queries) + + batch = self.tokenizer( + corpus, truncation=True, padding="longest", return_tensors="pt" + ).to(self.device) + encoded_summaries = self.model.generate(**batch) + summaries = self.tokenizer.batch_decode( + encoded_summaries, skip_special_tokens=True + ) + + return summaries + + @classmethod + def show_capability(cls) -> None: + # TODO zhangir: add the show 
capability function for BART + print(cls.generate_basic_description()) diff --git a/model/single_doc/base_single_doc_model.py b/model/single_doc/base_single_doc_model.py new file mode 100644 index 0000000000000000000000000000000000000000..079700afaa3a270bf2424a0bb75a71cccc861a10 --- /dev/null +++ b/model/single_doc/base_single_doc_model.py @@ -0,0 +1,36 @@ +from model.base_model import SummModel + + +class SingleDocSummModel(SummModel): + def __init__( + self, + trained_domain: str = None, + max_input_length: int = None, + max_output_length: int = None, + ): + super(SingleDocSummModel, self).__init__( + trained_domain=trained_domain, + max_input_length=max_input_length, + max_output_length=max_output_length, + ) + + @classmethod + def assert_summ_input_type(cls, corpus, query): + if not isinstance(corpus, list): + raise TypeError( + "Single-document summarization requires corpus of `List[str]`." + ) + if not all([isinstance(ins, str) for ins in corpus]): + raise TypeError( + "Single-document summarization requires corpus of `List[str]`." + ) + + if query is not None: + if not isinstance(query, list): + raise TypeError( + "Query-based single-document summarization requires query of `List[str]`." + ) + if not all([isinstance(q, str) for q in query]): + raise TypeError( + "Query-based single-document summarization requires query of `List[str]`." + ) diff --git a/model/single_doc/lexrank_model.py b/model/single_doc/lexrank_model.py new file mode 100644 index 0000000000000000000000000000000000000000..98582b0fe4560bb02a3020739ecb1f73bae3f25d --- /dev/null +++ b/model/single_doc/lexrank_m odel.py @@ -0,0 +1,50 @@ +from lexrank import STOPWORDS +from lexrank import LexRank as LR +import nltk + +from .base_single_doc_model import SingleDocSummModel + + +class LexRankModel(SingleDocSummModel): + # static variables + model_name = "LexRank" + is_extractive = True + is_neural = False + + def __init__(self, data, summary_length=2, threshold=0.1): + super(LexRankModel, self).__init__() + + nltk.download("punkt", quiet=True) + corpus = [nltk.sent_tokenize(example) for example in data] + self.lxr = LR(corpus, stopwords=STOPWORDS["en"]) + self.summary_length = summary_length + self.threshold = threshold + + def summarize(self, corpus, queries=None): + self.assert_summ_input_type(corpus, queries) + + documents = [nltk.sent_tokenize(document) for document in corpus] + summaries = [ + " ".join( + self.lxr.get_summary( + document, summary_size=self.summary_length, threshold=self.threshold + ) + ) + for document in documents + ] + + return summaries + + @classmethod + def show_capability(cls): + basic_description = cls.generate_basic_description() + more_details = ( + "Works by using a graph-based method to identify the most salient sentences in the document. \n" + "Strengths: \n - Fast with low memory usage \n - Allows for control of summary length \n " + "Weaknesses: \n - Not as accurate as neural methods. \n " + "Initialization arguments: \n " + "- `corpus`: Unlabelled corpus of documents. ` \n " + "- `summary_length`: sentence length of summaries \n " + "- `threshold`: Level of salience required for sentence to be included in summary." 
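            # Note: unlike most other models in this package, LexRank must be given an
            # unlabelled corpus at construction time, which it uses to fit the statistics
            # behind its sentence-similarity graph.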
+ ) + print(f"{basic_description} \n {'#'*20} \n {more_details}") diff --git a/model/single_doc/longformer_model.py b/model/single_doc/longformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc406c7f6ed91cb2b678e1dddbfdaeadb189c84 --- /dev/null +++ b/model/single_doc/longformer_model.py @@ -0,0 +1,57 @@ +from transformers import LongformerTokenizer, EncoderDecoderModel +from .base_single_doc_model import SingleDocSummModel + + +class LongformerModel(SingleDocSummModel): + + # static variables + model_name = "Longformer" + is_extractive = False + is_neural = True + + def __init__(self): + super(LongformerModel, self).__init__() + + self.model = EncoderDecoderModel.from_pretrained( + "patrickvonplaten/longformer2roberta-cnn_dailymail-fp16" + ) + self.tokenizer = LongformerTokenizer.from_pretrained( + "allenai/longformer-base-4096" + ) + + def summarize(self, corpus, queries=None): + self.assert_summ_input_type(corpus, queries) + + summaries = list(map(lambda doc: self.summarize_single(doc), corpus)) + + return summaries + + def summarize_single(self, document): + # Tokenizes document and returns PyTorch torch.Tensor object with length attribute + tokenized_sequence = self.tokenizer( + document, + return_tensors="pt", + return_length=True, + truncation=True, + max_length=4096, + ) + print( + f"Longformer model: processing document of {tokenized_sequence.length} tokens" + ) + input_ids = tokenized_sequence.input_ids + # output_ids is tensor with one layer: output_ids[0] extracts tensor layer for decoding + output_ids = self.model.generate(input_ids) + + return self.tokenizer.decode(output_ids[0], skip_special_tokens=True) + + @classmethod + def show_capability(cls) -> None: + basic_description = cls.generate_basic_description() + more_details = ( + "A Longformer2Roberta model finetuned on CNN-DM dataset for summarization.\n\n" + "Strengths:\n - Correctly handles longer (> 2000 tokens) corpus.\n\n" + "Weaknesses:\n - Less accurate on contexts outside training domain.\n\n" + "Initialization arguments:\n " + " - `corpus`: Unlabelled corpus of documents.\n" + ) + print(f"{basic_description} \n {'#'*20} \n {more_details}") diff --git a/model/single_doc/pegasus_model.py b/model/single_doc/pegasus_model.py new file mode 100644 index 0000000000000000000000000000000000000000..91580ad6a57386276ba443e51a472d9b2d982f9f --- /dev/null +++ b/model/single_doc/pegasus_model.py @@ -0,0 +1,50 @@ +from transformers import PegasusForConditionalGeneration, PegasusTokenizer +from .base_single_doc_model import SingleDocSummModel + + +class PegasusModel(SingleDocSummModel): + # static variables + model_name = "Pegasus" + is_extractive = False + is_neural = True + + def __init__(self, device="cpu"): + super(PegasusModel, self).__init__() + + self.device = device + model_name = "google/pegasus-xsum" + print("init load pretrained tokenizer") + self.tokenizer = PegasusTokenizer.from_pretrained(model_name) + print("init load pretrained model with tokenizer on " + device) + # self.model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device) + self.model = PegasusForConditionalGeneration.from_pretrained(model_name) + + def summarize(self, corpus, queries=None): + self.assert_summ_input_type(corpus, queries) + + print("batching") + # batch = self.tokenizer(corpus, truncation=True, padding='longest', return_tensors="pt").to(self.device) + batch = self.tokenizer(corpus, truncation=True, return_tensors="pt") + print("encoding batches") + # encoded_summaries = 
self.model.generate(**batch, max_length=40, max_time=120) + encoded_summaries = self.model.generate(batch["input_ids"], max_time=1024) + print("decoding batches") + # summaries = self.tokenizer.batch_decode(encoded_summaries, skip_special_tokens=True) + summaries = [self.tokenizer.decode(encoded_summaries[0])] + + return summaries + + @classmethod + def show_capability(cls): + basic_description = cls.generate_basic_description() + more_details = ( + "Introduced in 2019, a large neural abstractive summarization model trained on web crawl and " + "news data.\n " + "Strengths: \n - High accuracy \n - Performs well on almost all kinds of non-literary written " + "text \n " + "Weaknesses: \n - High memory usage \n " + "Initialization arguments: \n " + "- `device = 'cpu'` specifies the device the model is stored on and uses for computation. " + "Use `device='gpu'` to run on an Nvidia GPU." + ) + print(f"{basic_description} \n {'#'*20} \n {more_details}") diff --git a/model/single_doc/textrank_model.py b/model/single_doc/textrank_model.py new file mode 100644 index 0000000000000000000000000000000000000000..233d57559d1db67ece3a7ba27a63b94b5a78a954 --- /dev/null +++ b/model/single_doc/textrank_model.py @@ -0,0 +1,89 @@ +import spacy +import pytextrank # noqa: F401 +from math import sqrt +from operator import itemgetter +from .base_single_doc_model import SingleDocSummModel +from typing import Union, List + + +class TextRankModel(SingleDocSummModel): + # static variables + model_name = "TextRank" + is_extractive = True + is_neural = False + + def __init__(self, num_sentences=1): + super(TextRankModel, self).__init__() + + self.num_sentences = num_sentences + # load a spaCy model, depending on language, scale, etc. + self.nlp = spacy.load("en_core_web_sm") + self.nlp.add_pipe("textrank", last=True) + + def summarize( + self, corpus: Union[List[str], List[List[str]]], queries: List[str] = None + ) -> List[str]: + self.assert_summ_input_type(corpus, queries) + + return list(map(lambda x: " ".join(self.summarize_single(x)), corpus)) + + def summarize_single(self, corpus) -> List[str]: + # add PyTextRa nk to the spaCy pipeline + doc = self.nlp(corpus) + sent_bounds = [[s.start, s.end, set([])] for s in doc.sents] + + limit_phrases = self.num_sentences + phrase_id = 0 + unit_vector = [] + for p in doc._.phrases: + unit_vector.append(p.rank) + for chunk in p.chunks: + for sent_start, sent_end, sent_vector in sent_bounds: + if chunk.start >= sent_start and chunk.end <= sent_end: + sent_vector.add(phrase_id) + break + phrase_id += 1 + if phrase_id == limit_phrases: + break + + sum_ranks = sum(unit_vector) + + unit_vector = [rank / sum_ranks for rank in unit_vector] + + sent_rank = {} + sent_id = 0 + for sent_start, sent_end, sent_vector in sent_bounds: + sum_sq = 0.0 + for phrase_id in range(len(unit_vector)): + if phrase_id not in sent_vector: + sum_sq += unit_vector[phrase_id] ** 2.0 + sent_rank[sent_id] = sqrt(sum_sq) + sent_id += 1 + + sorted(sent_rank.items(), key=itemgetter(1)) + + sent_text = {} + sent_id = 0 + limit_sentences = self.num_sentences + summary_sentences = [] + for sent in doc.sents: + sent_text[sent_id] = sent.text + sent_id += 1 + num_sent = 0 + for sent_id, rank in sorted(sent_rank.items(), key=itemgetter(1)): + summary_sentences.append(sent_text[sent_id]) + num_sent += 1 + if num_sent == limit_sentences: + break + + return summary_sentences + + @classmethod + def show_capability(cls): + basic_description = cls.generate_basic_description() + more_details = ( + "A graphbased ranking 
model for text processing. Extractive sentence summarization. \n " + "Strengths: \n - Fast with low memory usage \n - Allows for control of summary length \n " + "Weaknesses: \n - Not as accurate as neural methods." + ) + print(f"{basic_description} \n {'#'*20} \n {more_details}") diff --git a/model/third_party/HMNet/DataLoader/README.md b/model/third_party/HMNet/DataLoader/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0ed56d8a6bfa4680bbb2f169d35955927e52c494 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/README.md @@ -0,0 +1 @@ +This dataloader is adapted from Microsoft's [infinibatch](https://github.com/microsoft/infinibatch) implementation, which is a library of checkpointable iterators for randomized data loading of massive data sets in deep neural network training. \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/__init__.py b/model/third_party/HMNet/DataLoader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df61bf8713419f847d7c2ee8c6036797c7b03ef7 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/__init__.py @@ -0,0 +1 @@ +from .infinibatch.infinibatch import datasets, iterators diff --git a/model/third_party/HMNet/DataLoader/infinibatch/LICENSE b/model/third_party/HMNet/DataLoader/infinibatch/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/third_party/HMNet/DataLoader/infinibatch/README.md b/model/third_party/HMNet/DataLoader/infinibatch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b16159add8b0c1ce4ca42a47f832134c5cce7d69 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/README.md @@ -0,0 +1,23 @@ +# InfiniBatch + +To view the documentation, please clone the repository and go to docs/infinibatch/index.html + +To run unit tests, run the following command. +``` +python -m unittest discover -s test +``` + +When working on the documentation, install pdoc: +``` +pip install pdoc3 +``` +You can then start a local http server that dynamically updates the documentation: +``` +pdoc --template-dir do cs --http : infinibatch +``` + +We currently haven't set up the CI to automatically generate the documentation. +Before you merge anything into master, please delete the existing documentation in docs/infinibatch and run +``` +pdoc -o docs --template-dir docs --html infinibatch +``` \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/bin/block_randomize.py b/model/third_party/HMNet/DataLoader/infinibatch/bin/block_randomize.py new file mode 100644 index 0000000000000000000000000000000000000000..d20c3583db347e51cb8407e8fc63ae92b1bec178 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/bin/block_randomize.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +#!/usr/bin/python3.6 + +# simple command-line wrapper around the chunked_dataset_iterator +# Example: +# block_randomize my_chunked_data_folder/ +# block_randomize --azure-storage-key $MY_KEY https://myaccount.blob.core.windows.net/mycontainer/my_chunked_data_folder + +import os, sys, inspect + +sys.path.insert( + 0, + os.path.dirname( + os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) + ), +) # find our imports + +from infinibatch.datasets import chunked_dataset_iterator + +from typing import Union, Iterator, Callable, Any, Optional, Dict +import os, sys, re +import gzip + + +# helper functions to abstract access to Azure blobs +# @TODO: These will be abstracted into a helper library in a future version. +def _try_parse_azure_blob_uri(path: str): + try: + m = re.compile("https://([a-z0-9]*).blob.core.windows.net/([^/]*)/(.*)").match( + path + ) + # print (m.group(1)) + # print (m.group(2)) + # print (m.group(3)) + return (m.group(1), m.group(2), m.group(3)) + except: + return None + + +def _get_azure_key( + storage_account: str, credentials: Optional[Union[str, Dict[str, str]]] +): + if not credentials: + return None + elif isinstance(credentials, str): + return credentials + else: + return credentials[storage_account] + + +def read_utf8_file( + path: str, credentials: Optional[Union[str, Dict[str, str]]] +) -> Iterator[str]: + blob_data = _try_parse_azure_blob_uri(path) + if blob_data is None: + with open(path, "rb") as f: + data = f.read() + else: + try: + # pip install azure-storage-blob + from azure.storage.blob import BlobClient + except: + print( + "Failed to import azure.storage.blob. Please pip install azure-storage-blob", + file=sys.stderr, + ) + raise + data = ( + BlobClient.from_blob_url( + path, + credential=_get_azure_key( + storage_account=blob_data[0], credentials=credentials + ), + ) + .download_blob() + .readall() + ) + if path.endswith(".gz"): + data = gzip.decompress(data) + # @TODO: auto-detect UCS-2 by BOM + return iter(data.decode(encoding="utf-8").splitlines()) + + +def enumerate_files( + dir: str, ext: str, credentials: Optional[Union[str, Dict[str, str]]] +): + blob_data = _try_parse_azure_blob_uri(dir) + if blob_data is None: + return [ + os.path.join(dir, path.name) + for path in os.scandir(dir) + if path.is_file() and (ext is None or path.name.endswith(ext)) + ] + else: + try: + # pip install azure-storage-blob + from azure.storage.blob import ContainerClient + except: + print( + "Failed to import azure.storage.blob. 
Please pip install azure-storage-blob", + file=sys.stderr, + ) + raise + account, container, blob_path = blob_data + + print("enumerate_files: enumerating blobs in", dir, file=sys.stderr, flush=True) + # @BUGBUG: The prefix does not seem to have to sta rt; seems it can also be a substring + container_uri = "https://" + account + ".blob.core.windows.net/" + container + container_client = ContainerClient.from_container_url( + container_uri, credential=_get_azure_key(account, credentials) + ) + if not blob_path.endswith("/"): + blob_path += "/" + blob_uris = [ + container_uri + "/" + blob["name"] + for blob in container_client.walk_blobs(blob_path, delimiter="") + if (ext is None or blob["name"].endswith(ext)) + ] + print( + "enumerate_files:", + len(blob_uris), + "blobs found", + file=sys.stderr, + flush=True, + ) + for blob_name in blob_uris[:10]: + print(blob_name, file=sys.stderr, flush=True) + return blob_uris + + +if sys.argv[1] == "--azure-storage-key": + credential = sys.argv[2] + paths = sys.argv[3:] +else: + credential = None + paths = sys.argv[1:] + +chunk_file_paths = [ # enumerate all .gz files in the given paths + subpath for path in paths for subpath in enumerate_files(path, ".gz", credential) +] +chunk_file_paths.sort() # make sure file order is always the same, independent of OS +print( + "block_randomize: reading from", + len(chunk_file_paths), + "chunk files", + file=sys.stderr, +) + +ds = chunked_dataset_iterator( + chunk_refs=chunk_file_paths, + read_chunk_fn=lambda path: read_utf8_file(path, credential), + shuffle=True, + buffer_size=1000000, + seed=1, + use_windowed=True, +) +for line in ds: + print(line) diff --git a/model/third_party/HMNet/DataLoader/infinibatch/bin/block_randomize_and_batch.py b/model/third_party/HMNet/DataLoader/infinibatch/bin/block_randomize_and_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6cc8f0a3adcd0fa5b76fc18a5148395f869b2c --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/bin/block_randomize_and_batch.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +#!/usr/bin/python3.6 + +# simple command-line wrapper around BucketedReadaheadBatchIterator on a IterableChunkedDataset +# Example: +# block_randomize_and_batch my_chunked_data + +import os, sys, inspect + +sys.path.insert( + 0, + os.path.dirname( + os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) + ), +) # find our imports + +from infinibatch.datasets import chunked_dataset_iterator +from infinibatch.iterators import BucketedReadaheadBatchIterator + +sets = sys.argv[1:] + +ds = chunked_dataset_iterator(sets, shuffle=True, buffer_size=10000000, seed=1) +batch_labels = 500 +bg = BucketedReadaheadBatchIterator( + ds, + read_ahead=100, + key=lambda line: len(line), + batch_size=lambda line: batch_labels // (1 + len(line)), + seed=1, +) +for batch in bg: + print(f"\n---- size {len(batch)} ---\n") + print("\n".join(batch)) diff --git a/model/third_party/HMNet/DataLoader/infinibatch/docs/config.mako b/model/third_party/HMNet/DataLoader/infinibatch/docs/config.mako new file mode 100644 index 0000000000000000000000000000000000000000..b6b0e8da72e870314441c80638908c2626f0d525 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/docs/config.mako @@ -0,0 +1,41 @@ +<%! + # This is a configuration file for pdoc3, the tool we use for generating html documentation from docstrings. + # Please look at the README.md for instruction on how to generate the documentation. 
+ # Template configuration. Copy over in your template directory + # (used with --template-dir) and adapt as required. + html_lang = 'en' + show_inherited_members = False + extract_module_toc_into_sidebar = True + list_class_variables_in_index = True + sort_identifiers = False + show_type_annotations = True + # Show collapsed source code block next to each item. + # Disabling this can improve rendering speed of large modules. + show_source_code = True + # If set, format link s to objects in online source code repository + # according to this template. Supported keywords for interpolation + # are: commit, path, start_line, end_line. + #git_link_template = 'https://github.com/USER/PROJECT/blob/{commit}/{path}#L{start_line}-L{end_line}' + #git_link_template = 'https://gitlab.com/USER/PROJECT/blob/{commit}/{path}#L{start_line}-L{end_line}' + #git_link_template = 'https://bitbucket.org/USER/PROJECT/src/{commit}/{path}#lines-{start_line}:{end_line}' + #git_link_template = 'https://CGIT_HOSTNAME/PROJECT/tree/{path}?id={commit}#n{start-line}' + git_link_template = None + # A prefix to use for every HTML hyperlink in the generated documentation. + # No prefix results in all links being relative. + link_prefix = '' + # Enable syntax highlighting for code/source blocks by including Highlight.js + syntax_highlighting = True + # Set the style keyword such as 'atom-one-light' or 'github-gist' + # Options: https://github.com/highlightjs/highlight.js/tree/master/src/styles + # Demo: https://highlightjs.org/static/demo/ + hljs_style = 'github' + # If set, insert Google Analytics tracking code. Value is GA + # tracking id (UA-XXXXXX-Y). + google_analytics = '' + # If set, render LaTeX math syntax within \(...\) (inline equations), + # or within \[...\] or $$...$$ or `.. math::` (block equations) + # as nicely-formatted math formulas using MathJax. + # Note: in Python docstrings, either all backslashes need to be escaped (\\) + # or you need to use raw r-strings. + latex_math = False +%> \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/closablequeue.html b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/closablequeue.html new file mode 100644 index 0000000000000000000000000000000000000000..c34daf178470f98409a676fc9c58d34451d8988a --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/closablequeue.html @@ -0,0 +1,279 @@ + + + + + + +infinibatch.closablequeue API documentation + + + + + + + + + +
+
+
+

Module infinibatch.closablequeue

+
+
+
+ +Expand source code + +
from collections import deque
+from threading import Condition, Lock, Thread
+
+
+class ClosedException(Exception):
+    pass
+
+
+class ClosableQueue:
+    """
+    A thread-safe queue that can be closed
+
+    As long as the queue is not closed, it behaves just like a thread-safe queue with a capacity limit:
+        - put blocks until the item can be added
+        - get blocks until there is an item to be returned
+
+    Once the queue is closed, no more items can be added but existing items can be removed:
+        - put always raises a ClosedException
+        - get returns an item if the queue is not empty and otherwise raises a ClosedException
+    """
+    def __init__(self, maxsize: int=1000):
+        self._maxsize = maxsize
+        self._queue = deque()
+        self._mutex = Lock()
+        self._not_empty = Condition(self._mutex)
+        self._not_full = Condition(self._mutex)
+        self._closed = False
+
+    def put(self, item):
+        with self._not_full:
+            if self._closed:
+                raise ClosedException('This queue has been closed, no more items can be added.')
+            while len(self._queue) >= self._maxsize:
+                self._not_full.wait()
+                if self._closed:
+                    raise ClosedException('This queue has been closed, no more items can be added.')
+            self._queue.append(item)
+            self._not_empty.notify()
+        
+    def get(self):
+        with self._not_empty:
+            if self._closed and len(self._queue) == 0:
+                raise ClosedException('This queue has been closed and is empty, no more items can be retrieved.')
+            while len(self._queue) == 0:
+                self._not_empty.wait()
+                if self._closed and len(self._queue) == 0:
+                    raise ClosedException('This queue has been closed and is empty, no more items can be retrieved.')
+            item = self._queue.popleft()
+            self._not_full.notify()
+        return item
+            
+    def close(self):
+        with self._mutex:
+            self._closed = True
+            self._not_empty.notify_all()
+            self._not_full.notify_all()
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ClosedException +(...) +
+
+

Common base class for all non-exit exceptions.

+
+ +Expand source code + +
class ClosedException(Exception):
+    pass
+
+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class ClosableQueue +(maxsize: int = 1000) +
+
+

A thread-safe queue that can be closed

+

As long as the the queue is not closed, it behaves just like a thread-safe queue with a capacity limit: +- put blocks until the item can be added +- get blocks until there is an item to be returned

+

Once the queue is closed, no more items can be added but existing items can be removed: +- put always raises a ClosedException +- get returns an item if the queue is not empty and otherwise raises a ClosedException

+
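For orientation, a small producer/consumer sketch that relies only on the put/get/close semantics described above (the threading setup is illustrative, not part of this module):

```python
from threading import Thread
from infinibatch.closablequeue import ClosableQueue, ClosedException

queue = ClosableQueue(maxsize=10)

def produce():
    for item in range(100):
        queue.put(item)            # blocks while the queue is at capacity
    queue.close()                  # no further items; waiting consumers are released

def consume():
    total = 0
    while True:
        try:
            total += queue.get()   # blocks until an item arrives
        except ClosedException:
            break                  # queue closed and fully drained
    print("consumed sum:", total)

producer, consumer = Thread(target=produce), Thread(target=consume)
producer.start(); consumer.start()
producer.join(); consumer.join()
```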
+ +Expand source code + +
class ClosableQueue:
+    """
+    A thread-safe queue that can be closed
+
+    As long as the queue is not closed, it behaves just like a thread-safe queue with a capacity limit:
+        - put blocks until the item can be added
+        - get blocks until there is an item to be returned
+
+    Once the queue is closed, no more items can be added but existing items can be removed:
+        - put always raises a ClosedException
+        - get returns an item if the queue is not empty and otherwise raises a ClosedException
+    """
+    def __init__(self, maxsize: int=1000):
+        self._maxsize = maxsize
+        self._queue = deque()
+        self._mutex = Lock()
+        self._not_empty = Condition(self._mutex)
+        self._not_full = Condition(self._mutex)
+        self._closed = False
+
+    def put(self, item):
+        with self._not_full:
+            if self._closed:
+                raise ClosedException('This queue has been closed, no more items can be added.')
+            while len(self._queue) >= self._maxsize:
+                self._not_full.wait()
+                if self._closed:
+                    raise ClosedException('This queue has been closed, no more items can be added.')
+            self._queue.append(item)
+            self._not_empty.notify()
+        
+    def get(self):
+        with self._not_empty:
+            if self._closed and len(self._queue) == 0:
+                raise ClosedException('This queue has been closed and is empty, no more items can be retrieved.')
+            while len(self._queue) == 0:
+                self._not_empty.wait()
+                if self._closed and len(self._queue) == 0:
+                    raise ClosedException('This queue has been closed and is empty, no more items can be retrieved.')
+            item = self._queue.popleft()
+            self._not_full.notify()
+        return item
+            
+    def close(self):
+        with self._mutex:
+            self._closed = True
+            self._not_empty.notify_all()
+            self._not_full.notify_all()
+
+

Methods

+
+
+def put(self, item) +
+
+
+
+ +Expand source code + +
def put(self, item):
+    with self._not_full:
+        if self._closed:
+            raise ClosedException('This queue has been closed, no more items can be added.')
+        while len(self._queue) >= self._maxsize:
+            self._not_full.wait()
+            if self._closed:
+                raise ClosedException('This queue has been closed, no more items can be added.')
+        self._queue.append(item)
+        self._not_empty.notify()
+
+
+
+def get(self) +
+
+
+
+ +Expand source code + +
def get(self):
+    with self._not_empty:
+        if self._closed and len(self._queue) == 0:
+            raise ClosedException('This queue has been closed and is empty, no more items can be retrieved.')
+        while len(self._queue) == 0:
+            self._not_empty.wait()
+            if self._closed and len(self._queue) == 0:
+                raise ClosedException('This queue has been closed and is empty, no more items can be retrieved.')
+        item = self._queue.popleft()
+        self._not_full.notify()
+    return item
+
+
+
+def close(self) +
+
+
+
+ +Expand source code + +
def close(self):
+    with self._mutex:
+        self._closed = True
+        self._not_empty.notify_all()
+        self._not_full.notify_all()
+
+
+
+
+
+
+
+ +
+ + + + + \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/datasets.html b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/datasets.html new file mode 100644 index 0000000000000000000000000000000000000000..bcd7bcb81e9e2e6c0700fbf10d31fdc35f8576ee --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/datasets.html @@ -0,0 +1,242 @@ + + + + + + +infinibatch.datasets API documentation + + + + + + + + + +
+
+
+

Module infinibatch.datasets

+
+
+
+ +Expand source code + +
from .iterators import create_source_iterator, SelectManyIterator, PrefetchIterator, BufferedShuffleIterator, BlockwiseShuffleIterator, MapIterator
+from typing import List, Union, Iterable, Iterator, Callable, Any, Optional, Dict
+import os, sys
+
+"""
+This module contains common datasets, which are implemented as convenience functions that compose underlying Infinibatch iterators.
+"""
+
+
+def bump_seed(seed: Optional[int], step = 1):
+    """
+    Helper to bump a random seed if not None.
+    """
+    return None if seed is None else seed + 1
+
+
+def chunked_dataset_iterator(chunk_refs: List, read_chunk_fn: Callable[[Any], Iterator], buffer_size: int,
+                             train: bool=True,
+                             seed: Optional[int]=None, shuffle: bool=True, use_windowed: bool=False,
+                             transform: Callable[[Any],Any]=None,
+                             prefetch: bool=True,
+                             num_instances: int=1, instance_rank: int=0):
+    """
+    Dataset reading data from gzipped chunks.
+
+    If train=True, the chunks are assigned to instances in strides and the data is infinitely repeated in permutations.
+    Otherwise, the chunks are split among the instances in consecutive blocks and the data is not repeated.
+    This way, when using this dataset for inference on multiple GPUs, to order the outputs in a way that corresponds
+    to the original order of the data items in the dataset, one simply has to collect the lists of outputs from each GPU
+    and then concatenate these lists in order of increasing rank.
+    When using MPI, this can be achieved by a gather-operation to get a list of lists of outputs, one list per GPU,
+    followed by flattening the lists back into a single list.
+
+    Args:
+        chunk_refs: references (such as path names) to chunk files
+        read_chunk_fn: function(chunk_ref) -> Iterator to read a chunk's content into an iterator over its items, e.g. read a file and split into text lines
+        train: see above
+        shuffle: if true, the data is shuffled. If train is False then shuffle must be False as well.
+        buffer_size: size of the buffer in number of samples / data items used for shuffling (default: 2**20)
+        transform: transform to be applied to each data item (transform(Any) -> Any)
+        prefetch: if True, insert a prefetch iterator with buffer_size
+        seed: random seed (or None)
+        num_instances: number of instances of this dataset. Meant for use with multi-process data loading, e.g., in distributed training.
+        instance_rank: rank of this instance of the dataset. Meant for use with multi-process data loading, e.g., in distributed training.
+        use_windowed: temporary option to switch back to the WindowedShuffleIterator (default False). Will go away once shown that we don't need it anymore.
+    """
+    if not train and shuffle:
+        raise ValueError('shuffling is not supported when train=False')
+    # set up the chunk reader
+    chunk_refs = create_source_iterator(chunk_refs, train=train, seed=seed, shuffle=shuffle, num_instances=num_instances, instance_rank=instance_rank)
+    # set up the item reader
+    samples = SelectManyIterator(source_iterator=chunk_refs, collection_selector=read_chunk_fn)
+    # wrap the I/O operation in a prefetch iterator
+    if prefetch:
+        samples = PrefetchIterator(samples, buffer_size)
+    # set up the item randomizer
+    if shuffle:
+        if use_windowed:
+            samples = BufferedShuffleIterator(samples, buffer_size, bump_seed(seed, 1))
+        else:
+            samples = BlockwiseShuffleIterator(samples, buffer_size, bump_seed(seed, 1))
+    # apply transform, if given
+    if transform is not None:
+        samples = MapIterator(samples, transform)
+    # this is what we are serving out
+    return samples
+
+
+
+
+
+
+
+

Functions

+
+
+def bump_seed(seed: Union[int, NoneType], step=1) +
+
+

Helper to bump a random seed if not None.

+
+ +Expand source code + +
def bump_seed(seed: Optional[int], step = 1):
+    """
+    Helper to bump a random seed if not None.
+    """
+    return None if seed is None else seed + 1
+
+
+
+def chunked_dataset_iterator(chunk_refs: List, read_chunk_fn: Callable[[Any], Iterator], buffer_size: int, train: bool = True, seed: Union[int, NoneType] = None, shuffle: bool = True, use_windowed: bool = False, transform: Callable[[Any], Any] = None, prefetch: bool = True, num_instances: int = 1, instance_rank: int = 0) +
+
+

Dataset reading data from gzipped chunks.

+

If train=True, this chunks are strided assigned to instances in strides and the data is infinitely repeated in permutations. +Otherwise, the chunks are split among the instances in consecutive blocks and the data is not repeated. +This way, when using this dataset for inference on multiple GPUs, to order the outputs in a way that corresponds +to the original order of the data items in the dataset, one simply has to collect the lists of outputs from each GPU +and then concatenate these lists in order of increasing rank. +When using MPI, this can be achieved by a gather-operation to get a list of lists of outputs, one list per GPU, +followed by flattening the lists back into a single list.

+

Args

+
+
chunk_refs
+
references (such as path names) to chunk files
+
read_chunk_fn
+
function(chunk_ref) -> Iterator to read a chunk's content into an iterator over its items, e.g. read a file and split into text lines
+
train
+
see above
+
shuffle
+
if true, the data is shuffled. If train is False then shuffle must be False as well.
+
buffer_size
+
size of the buffer in number of samples / data items used for shuffling (default: 2**20)
+
transform
+
transform to be applied to each data item (transform(Any) -> Any)
+
prefetch
+
if True, insert a prefetch iterator with buffer_size
+
seed
+
random seed (or None)
+
num_instances
+
number of instances of this dataset. Meant for use with multi-process data loading, e.g., in distributed training.
+
instance_rank
+
rank of this instance of the dataset. Meant for use with multi-process data loading, e.g., in distributed training.
+
use_windowed
+
temporary option to switch back to the WindowedShuffleIterator (default False). Will go away once shown that we don't need it anymore.
+
+
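As a hedged illustration of the train=False, multi-instance case described in these arguments (the chunk paths and read function below are placeholders):

```python
import gzip, glob
from infinibatch.datasets import chunked_dataset_iterator

# Each of 4 workers reads its own consecutive block of chunks, once, in order.
chunk_paths = sorted(glob.glob("corpus_chunks/corpus.*.txt.gz"))  # placeholder location
samples = chunked_dataset_iterator(
    chunk_refs=chunk_paths,
    read_chunk_fn=lambda path: iter(
        gzip.decompress(open(path, "rb").read()).decode("utf-8").splitlines()
    ),
    buffer_size=1000,
    train=False,          # consecutive blocks, no repetition
    shuffle=False,        # shuffle must be False when train=False
    num_instances=4,      # total number of workers
    instance_rank=0,      # rank of this worker
)
outputs = list(samples)   # finite when train=False; concatenate across ranks afterwards
```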
+ +Expand source code + +
def chunked_dataset_iterator(chunk_refs: List, read_chunk_fn: Callable[[Any], Iterator], buffer_size: int,
+                             train: bool=True,
+                             seed: Optional[int]=None, shuffle: bool=True, use_windowed: bool=False,
+                             transform: Callable[[Any],Any]=None,
+                             prefetch: bool=True,
+                             num_instances: int=1, instance_rank: int=0):
+    """
+    Dataset reading data from gzipped chunks.
+
+    If train=True, the chunks are assigned to instances in strides and the data is infinitely repeated in permutations.
+    Otherwise, the chunks are split among the instances in consecutive blocks and the data is not repeated.
+    This way, when using this dataset for inference on multiple GPUs, to order the outputs in a way that corresponds
+    to the original order of the data items in the dataset, one simply has to collect the lists of outputs from each GPU
+    and then concatenate these lists in order of increasing rank.
+    When using MPI, this can be achieved by a gather-operation to get a list of lists of outputs, one list per GPU,
+    followed by flattening the lists back into a single list.
+
+    Args:
+        chunk_refs: references (such as path names) to chunk files
+        read_chunk_fn: function(chunk_ref) -> Iterator to read a chunk's content into an iterator over its items, e.g. read a file and split into text lines
+        train: see above
+        shuffle: if true, the data is shuffled. If train is False then shuffle must be False as well.
+        buffer_size: size of the buffer in number of samples / data items used for shuffling (default: 2**20)
+        transform: transform to be applied to each data item (transform(Any) -> Any)
+        prefetch: if True, insert a prefetch iterator with buffer_size
+        seed: random seed (or None)
+        num_instances: number of instances of this dataset. Meant for use with multi-process data loading, e.g., in distributed training.
+        instance_rank: rank of this instance of the dataset. Meant for use with multi-process data loading, e.g., in distributed training.
+        use_windowed: temporary option to switch back to the WindowedShuffleIterator (default False). Will go away once shown that we don't need it anymore.
+    """
+    if not train and shuffle:
+        raise ValueError('shuffling is not supported when train=False')
+    # set up the chunk reader
+    chunk_refs = create_source_iterator(chunk_refs, train=train, seed=seed, shuffle=shuffle, num_instances=num_instances, instance_rank=instance_rank)
+    # set up the item reader
+    samples = SelectManyIterator(source_iterator=chunk_refs, collection_selector=read_chunk_fn)
+    # wrap the I/O operation in a prefetch iterator
+    if prefetch:
+        samples = PrefetchIterator(samples, buffer_size)
+    # set up the item randomizer
+    if shuffle:
+        if use_windowed:
+            samples = BufferedShuffleIterator(samples, buffer_size, bump_seed(seed, 1))
+        else:
+            samples = BlockwiseShuffleIterator(samples, buffer_size, bump_seed(seed, 1))
+    # apply transform, if given
+    if transform is not None:
+        samples = MapIterator(samples, transform)
+    # this is what we are serving out
+    return samples
+
+
+
+
+
+
+
+ +
+ + + + + \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/index.html b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/index.html new file mode 100644 index 0000000000000000000000000000000000000000..b121c03951b6400592ed517bb0b6d8c94ff2b842 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/index.html @@ -0,0 +1,629 @@ + + + + + + +infinibatch API documentation + + + + + + + + + +
+
+
+

Module infinibatch

+
+
+

Infinibatch is a library of checkpointable iterators for randomized data loading of massive data sets in deep neural network training.

+

Features

+
    +
  • support for corpora much larger than fit into RAM
  • +
  • hierarchical block+sentence-level randomization over the whole corpus, different randomization in each epoch
  • +
  • only load the data that is needed
  • +
  • very fast start-up time (does not need to read full corpus)
  • +
  • only requires the most basic of data preparation (e.g. no indexing)
  • +
  • for multi-GPU, only load what the respective GPU needs
  • +
  • 100% accurate check-pointing, restore from checkpoint should not read all data up to the checkpoint
  • +
  • support automatic bucketed batching with dynamic batch sizes
  • +
  • pre-fetching thread
  • +
  • composable, as to support for complex batching, e.g. negative samples from multiple documents
  • +
+

Getting Started

+

Infinibatch requires Python 3.5 and has no dependencies. +There is presently no pip package. +To install it, please copy this library into a subfolder in your project:

+
cd YOUR_PROJECT_FOLDER
+git clone <https://msasg.visualstudio.com/DefaultCollection/SDRG/_git/infinibatch>
+
+

or, better, as a submodule reference:

+
git submodule add <https://msasg.visualstudio.com/DefaultCollection/SDRG/_git/infinibatch>
+
+

It is now located at infinibatch/infinibatch, e.g. the main import file is infinibatch/infinibatch/__init__.py.

+

To import it, you need to add that folder to your PYTHONPATH variable externally, or to sys.path inside the code:

+
import sys
+sys.path.insert(0,'infinibatch')  # note: relative paths are relative to your current dir, not to the python script
+import infinibatch
+
+

Tutorial

+

This little tutorial walks you through the steps of preparing your data and consuming them from Python code as batches.

+

Infinibatch Basics: Iterators and Checkpointing

+

Infinibatch provides Python iterators +to read your data. +An iterator represents a stream of data that can be retrieved item by item, e.g. via a +for loop or repeatedly calling next() on it.

+

Infinibatch is agnostic to the data type of the items, which is determined by a user-supplied file-read function. +In NLP applications, items would typically be tuples of text. In other applications, +they can be images or an audio file with a textual annotation.

+

Infinibatch makes it easy to read your data in randomized order, and supports checkpointing, which allows you to restart training exactly where you left off.

+

Randomization is done on the fly, which means that it is not necessary to read the entire data set into memory +to be shuffled. Infinibatch implements a hierarchical shuffling algorithm +that only holds a subset of the data in RAM at any point in time.

+

Infinibatch iterators are checkpointable. +Checkpointing lets you retrieve the current position (the "checkpoint") in the data stream at any time, so that +later, you can "rewind" to that same position. +The sad reality is that long-running trainings occasionally crash. +To be able to continue a crashed training as if it had not crashed, +save your Infinibatch iterator's checkpoint to disk whenever you save an intermediate model during training. +To restart a crashed training, reset the iterator to the saved checkpoint. +The data reader will now yield the exact same data-item sequence it would have yielded without the crash.

+
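To make this concrete, here is a hedged sketch of saving and restoring a data checkpoint; the getstate()/setstate() names follow Infinibatch's checkpointable-iterator interface and should be verified against the version you use. Here, ds stands for an iterator such as the one created later in this tutorial:

```python
import pickle

# ds: any Infinibatch iterator, e.g. the chunked_dataset_iterator built later in this tutorial
checkpoint = ds.getstate()                 # current position in the data stream
with open("data_checkpoint.pkl", "wb") as f:
    pickle.dump(checkpoint, f)             # save next to your model checkpoint

# ...after a crash or restart, rebuild the same iterator, then:
with open("data_checkpoint.pkl", "rb") as f:
    ds.setstate(pickle.load(f))            # resume the exact same item sequence
```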

Data Preparation

+

Infinibatch has one requirement on your data organization: +To use your data with Infinibatch, it must be split into a large number of small chunks. +A chunk is the smallest unit of data that is loaded from disk into RAM. Infinibatch holds a random subset of chunks in memory +that it randomly draws samples from.

+

Below we want to show how such a split can be created. An easy way to split your data into chunks is with the Linux split command.

+

In this tutorial, our "corpus" consists of 6 lines of text, where each line is one data item. +To create that corpus, please run this command in a bash shell. It creates a 6-line text file named corpus.txt:

+
echo \
+'Lorem ipsum dolor sit amet,
+consectetur adipiscing elit,
+sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+The quick brown fox jumps over the lazy dog.' \
+> corpus.txt
+
+

Now let us split it into 3 chunks of 2 lines each. Each chunk is stored as a zipped text file. +We will create them inside a new subdirectory called corpus_chunks:

+
mkdir corpus_chunks
+split  --lines 2  --numeric-suffixes                 \
+       --filter 'gzip > corpus_chunks/$FILE.txt.gz'  \
+       corpus.txt  corpus.
+
+

This will have created three files: corpus_chunks/corpus.00.txt.gz, corpus_chunks/corpus.01.txt.gz, and corpus_chunks/corpus.02.txt.gz. +To verify whether the data has been split as expected, you can use this command:

+
zcat corpus_chunks/corpus.*.txt.gz
+
+

Hint: For large corpora, we recommend replacing gzip by pigz (apt-get install pigz), which runs notably faster via multi-threading.

+

Reading Items in Random Order With Infinibatch

+

We will first show the easiest way to read data with Infinibatch, using the helper function chunked_dataset_iterator``(). +This function will create an Infinibatch iterator that yields the content of your data in random order. +Please the following program:

+
import sys, gzip, glob
+sys.path.insert(0,'infinibatch')
+from infinibatch import datasets as ds
+
+ds = ds.chunked_dataset_iterator(
+    chunk_refs = glob.glob('corpus_chunks/corpus.*.txt.gz'),
+    read_chunk_fn = lambda path: iter(gzip.decompress(open(path, "rb")  \
+                                      .read()).decode(encoding='utf-8') \
+                                      .splitlines()),
+    buffer_size = 6, seed = 1)
+
+for i in range(10):
+    print(next(ds))
+
+

You should get output that contains the 6 example lines in randomized order:

+
Lorem ipsum dolor sit amet,
+consectetur adipiscing elit,
+Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+The quick brown fox jumps over the lazy dog.
+sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+consectetur adipiscing elit,
+Lorem ipsum dolor sit amet,
+The quick brown fox jumps over the lazy dog.
+sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+
+

Note: The buffer_size parameter determines how many sentences are read into memory at any given time, +to draw randomized items from. In real settings with corpora of hundreds of millions of text lines, +the buffer_size parameter should be set in the millions. +RAM usage and startup time will be proportional to the buffer size +(but much lower than having to load the entire corpus into RAM).

+

Reading Items of Different Lengths in Batches

+

For deep learning, we want to group multiple items into batches. +For NLP tasks, items are often lines of text of varying length. +Infinibatch implements an algorithm that randomizes the input sequence and groups it into +batches of approximately the same length (aka bucketing).

+

Infinibatch's BucketedReadaheadBatchIterator performs this task. +It implements an algorithm modeled after the Marian toolkit +that preloads a large number of randomized items (typically millions; in this example: 6), +sorts them and groups them into batches of similar length, and then yields +them, in turn, in randomized order.

+

Here is an example. Note that the BucketedReadaheadBatchIterator accepts +the previous randomized sentence sequence iterator (ds) as the source of items to randomize over. +This is an example how one forms pipelines of iterators with Infinibatch +(a concept familiar from Python's own itertools). +Once an iterator is passed to another as its source, consider it owned by that other iterator, +it must no longer be accessed by the calling code.

+
import sys, gzip, glob
+sys.path.insert(0,'infinibatch')
+from infinibatch import datasets as ds
+from infinibatch import iterators as it
+
+ds = ds.chunked_dataset_iterator(
+    chunk_refs = glob.glob('corpus_chunks/corpus.*.txt.gz'),
+    read_chunk_fn = lambda path: iter(gzip.decompress(open(path, "rb")  \
+                                      .read()).decode(encoding='utf-8') \
+                                      .splitlines()),
+    buffer_size = 6, seed = 1)
+
+bs = it.BucketedReadaheadBatchIterator(
+    source_iterator = ds,   # note: this is the iterator from above
+    read_ahead = 6,
+    key = lambda line: len(line),
+    batch_size = 2,
+    seed = 1)
+
+for i in range(25):
+    print(next(bs))
+
+

This code should output something like this:

+
['sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.',
+ 'The quick brown fox jumps over the lazy dog.']
+['consectetur adipiscing elit,', 'Lorem ipsum dolor sit amet,']
+['Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.',
+ 'Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.']
+
+

followed by different permutations of the same tuples. +As you can see, the sentences are in random order and grouped in batches of 2 of approximately the same length. +You may notice that there is no variation in how the items get grouped into batches–that +is an artifact of this example, and generally not the case in real use when the data size is much larger +than the batch size.

+

In NLP, sentence length often varies considerably. As a result, using batches of a fixed number of lines, +as in the example above, will waste GPU RAM and cores. +This is because the number of lines is limited by the longest possible sequence; batches of shorter lines +would leave GPU cycles on the table. +Ideally, one would use batches that have as many lines as fit into GPU RAM, +given the number of tokens of the longest line in the batch. +To support variable batch sizes, Infinibatch allows to pass a function as the batch_size parameter. +That function will be given the longest item of a batch and should estimate how many items of at most this length can fit.

+

In our example, we assume that batches can hold at most 150 tokens. +Please change the above code as follows:

+
    batch_size = lambda longest_line: 150 // len(longest_line),
+
+

The output looks like this:

+
['consectetur adipiscing elit,', 'Lorem ipsum dolor sit amet,']
+['Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.']
+['sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.',
+ 'The quick brown fox jumps over the lazy dog.']
+['Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.']
+
+

That shorter sentences got grouped, while longer did not because they would exceed the total of 150 characters.

+

Reading Batches Into Numpy Arrays

+

Lastly, we will need to feed batches into our favorite deep-learning tool. +We will show how to convert the batches of text lines into padded numpy arrays.

+

In a typical NLP application, text items would be tokenized, and then each token +would be represented by an index into a unit vocabulary. +For simplicity, in this example each character is its own token, +and each token's numeric unit index is just its ASCII code. +These sequences are then padded to equal length with -1, and converted into a numpy array.

+

Please rerun the previous example, but first insert the following code before the final for loop. +This example uses an Infinibatch MapIterator, which applies a user-supplied function or +lambda to each item:

+
import numpy as np
+def collate(lines_batch):
+    # tokenize all lines in the batch and map to unit ids
+    ids_batch = [[ord(c) for c in line] for line in lines_batch]
+    # create a padded numpy array as wide as the longest line,
+    # where shorter sequences are padded with -1
+    width = max(len(ids) for ids in ids_batch)
+    return np.array([ids + [-1] * (width-len(ids)) for ids in ids_batch])
+
+bs = it.MapIterator(
+    source_iterator = bs,
+    transform = collate)
+
+

This will output batches like this. Note that in batches with multiple sentences, +some entries are padded with -1.

+
[[ 99 111 110 115 101  99 116 101 116 117 114  32  97 100 105 112 105 115
+   99 105 110 103  32 101 108 105 116  44]
+ [ 76 111 114 101 109  32 105 112 115 117 109  32 100 111 108 111 114  32
+  115 105 116  32  97 109 101 116  44  -1]]
+[[ 85 116  32 101 110 105 109  32  97 100  32 109 105 110 105 109  32 118
+  101 110 105  97 109  44  32 113 117 105 115  32 110 111 115 116 114 117
+  100  32 101 120 101 114  99 105 116  97 116 105 111 110  32 117 108 108
+   97 109  99 111  32 108  97  98 111 114 105 115  32 110 105 115 105  32
+  117 116  32  97 108 105 113 117 105 112  32 101 120  32 101  97  32  99
+  111 109 109 111 100 111  32  99 111 110 115 101 113 117  97 116  46]]
+[[115 101 100  32 100 111  32 101 105 117 115 109 111 100  32 116 101 109
+  112 111 114  32 105 110  99 105 100 105 100 117 110 116  32 117 116  32
+  108  97  98 111 114 101  32 101 116  32 100 111 108 111 114 101  32 109
+   97 103 110  97  32  97 108 105 113 117  97  46]
+ [ 84 104 101  32 113 117 105  99 107  32  98 114 111 119 110  32 102 111
+  120  32 106 117 109 112 115  32 111 118 101 114  32 116 104 101  32 108
+   97 122 121  32 100 111 103  46  -1  -1  -1  -1  -1  -1  -1  -1  -1  -1
+   -1  -1  -1  -1  -1  -1  -1  -1  -1  -1  -1  -1]]
+[[ 68 117 105 115  32  97 117 116 101  32 105 114 117 114 101  32 100 111
+  108 111 114  32 105 110  32 114 101 112 114 101 104 101 110 100 101 114
+  105 116  32 105 110  32 118 111 108 117 112 116  97 116 101  32 118 101
+  108 105 116  32 101 115 115 101  32  99 105 108 108 117 109  32 100 111
+  108 111 114 101  32 101 117  32 102 117 103 105  97 116  32 110 117 108
+  108  97  32 112  97 114 105  97 116 117 114  46]]
+
+

Where To Go From Here

+

The above tutorial showed you the use of the most common iterator type, as created by the +convenience function chunked_dataset_iterator().

+

Not all real-life scenarios are covered by this function. For example, multi-task learning +scenarios require more complex combinations of data. To create those, you will need +to compose the necessary data reader from the underlying building blocks. +This is described at the documentation of the module infinibatch.iterators.

+
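As a hedged sketch of such manual composition, the pipeline below mirrors what chunked_dataset_iterator wires together internally (see infinibatch/datasets.py in this change); the chunk paths and read function are placeholders you would supply:

```python
import gzip, glob
from infinibatch.iterators import (
    create_source_iterator, SelectManyIterator, PrefetchIterator,
    BlockwiseShuffleIterator, MapIterator,
)

chunk_paths = glob.glob("corpus_chunks/corpus.*.txt.gz")      # placeholder chunk files
read_chunk_fn = lambda path: iter(
    gzip.decompress(open(path, "rb").read()).decode("utf-8").splitlines()
)

chunk_refs = create_source_iterator(chunk_paths, train=True, seed=1, shuffle=True,
                                    num_instances=1, instance_rank=0)
samples = SelectManyIterator(source_iterator=chunk_refs, collection_selector=read_chunk_fn)
samples = PrefetchIterator(samples, 100000)                   # overlap I/O with training
samples = BlockwiseShuffleIterator(samples, 100000, 2)        # shuffle buffer size, seed
samples = MapIterator(samples, lambda line: line.strip())     # any per-item transform
```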
+ +Expand source code + +
"""
+Infinibatch is a library of checkpointable iterators for randomized data loading of massive data sets in deep neural network training.
+
+
+## Features
+
+  * support for corpora much larger than fit into RAM
+  * hierarchical block+sentence-level randomization over the whole corpus, different randomization in each epoch
+  * only load the data that is needed
+  * very fast start-up time (does not need to read full corpus)
+  * only requires the most basic of data preparation (e.g. no indexing)
+  * for multi-GPU, only load what the respective GPU needs
+  * 100% accurate check-pointing, restore from checkpoint should not read all data up to the checkpoint
+  * support automatic bucketed batching with dynamic batch sizes
+  * pre-fetching thread
+  * composable, as to support for complex batching, e.g. negative samples from multiple documents
+
+
+## Getting Started
+
+Infinibatch requires Python 3.5 and has no dependencies.
+There is presently no pip package.
+To install it, please copy this library into a subfolder in your project:
+```bash
+cd YOUR_PROJECT_FOLDER
+git clone https://msasg.visualstudio.com/DefaultCollection/SDRG/_git/infinibatch
+```
+or, better, as a submodule reference:
+```bash
+git submodule add https://msasg.visualstudio.com/DefaultCo
llection/SDRG/_git/infinibatch
+```
+It is now located at `infinibatch/infinibatch`, e.g. the main import file is `infinibatch/infinibatch/__init__.py`.
+
+To import it, you need to add that folder to your `PYTHONPATH` variable externally, or to `sys.path` inside the code:
+```python
+import sys
+sys.path.insert(0,'infinibatch')  # note: relative paths are relative to your current dir, not to the python script
+import infinibatch
+```
+
+## Tutorial
+
+This little tutorial walks you through the steps of preparing your data and consuming them from Python code as batches.
+
+### Infinibatch Basics: Iterators and Checkpointing
+
+Infinibatch provides [Python iterators](https://docs.python.org/3.5/glossary.html#term-iterator)
+to read your data.
+An iterator represents a stream of data that can be retrieved item by item, e.g. via a
+`for` loop or repeatedly calling `next()` on it.
+
+Infinibatch is agnostic to the data type of the items, which is determined by a user-supplied file-read function.
+In NLP applications, items would typically be tuples of text. In other applications,
+they can be images or an audio file with a textual annotation.
+
+Infinibatch makes it easy to read your data in randomized order, and supports checkpointing, which allows you to restart training exactly where you left off.
+
+Randomization is done _on the fly_, which means that it is not necessary to read the entire data set into memory
+to be shuffled. Infinibatch implements a hierarchical shuffling algorithm
+that only holds a subset of the data in RAM at any point in time.
+
+Infinibatch iterators are _checkpointable_.
+Checkpointing lets you retrieve the current position (the "checkpoint") in the data stream at any time, so that
+later, you can "rewind" to that same position.
+The sad reality is that long-running trainings occasionally crash.
+To be able to continue a crashed training as if it had not crashed,
+save your Infinibatch iterator's checkpoint to disk whenever you save an intermediate model during training.
+To restart a crashed training, reset the iterator to the saved checkpoint.
+The data reader will now yield the exact same data-item sequence it would have yielded without the crash.
+
+### Data Preparation
+
+Infinibatch has one requirement on your data organization:
+To use your data with Infinibatch, it must be split into a large number of small chunks.
+A chunk is the smallest unit of data that is loaded from disk into RAM. Infinibatch holds a random subset of chunks in memory
+that it randomly draws samples from.
+
+Below we want to show how such a split can be created. An easy way to split your data into chunks is with the Linux `split` command.
+
+In this tutorial, our "corpus" consists of 6 lines of text, where each line is one data item.
+To create that corpus, please run this command in a bash shell. It creates a 6-line text file named `corpus.txt`:
+```bash
+echo \\
+'Lorem ipsum dolor sit amet,
+consectetur adipiscing elit,
+sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+The quick brown fox jumps over the lazy dog.' \\
+> corpus.txt
+```
+Now let us split it into 3 chunks of 2 lines each. Each chunk is stored as a zipped text file.
+We will create them inside a new subdirectory called `corpus_chunks`:
+```bash
+mkdir corpus_chunks
+split  --lines 2  --numeric-suffixes                 \\
+       --filter 'gzip > corpus_chunks/$FILE.txt.gz'  \\
+       corpus.txt  corpus.
+```
+This will have created three files: `corpus_chunks/corpus.00.txt.gz`, `corpus_chunks/corpus.01.txt.gz`, and `corpus_chunks/corpus.02.txt.gz`.
+To verify whether the data has been split as expected, you can use this command:
+```bash
+zcat corpus_chunks/corpus.*.txt.gz
+```
+
+Hint: For large corpora, we recommend replacing `gzip` by `pigz` (`apt-get install pigz`), which runs notably faster via multi-threading.
+
+### Reading Items in Random Order With Infinibatch
+
+We will first show the easiest way to read data with Infinibatch, using the helper function `chunked_dataset_iterator()`.
+This function will create an Infinibatch iterator that yields the content of your data in random order.
+Please run the following program:
+```python
+import sys, gzip, glob
+sys.path.insert(0,'infinibatch')
+from infinibatch import datasets as ds
+
+ds = ds.chunked_dataset_iterator(
+    chunk_refs = glob.glob('corpus_chunks/corpus.*.txt.gz'),
+    read_chunk_fn = lambda path: iter(gzip.decompress(open(path, "rb")  \\
+                                      .read()).decode(encoding='utf-8') \\
+                                      .splitlines()),
+    buffer_size = 6, seed = 1)
+
+for i in range(10):
+    print(next(ds))
+```
+You should get output that contains the 6 example lines in randomized order:
+```text
+Lorem ipsum dolor sit amet,
+consectetur adipiscing elit,
+Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+The quick brown fox jumps over the lazy dog.
+sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+consectetur adipiscing elit,
+Lorem ipsum dolor sit amet,
+The quick brown fox jumps over the lazy dog.
+sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+```
+Note: The `buffer_size` parameter determines how many sentences are read into memory at any given time,
+to draw randomized items from. In real settings with corpora of hundreds of millions of text lines,
+the `buffer_size` parameter should be set in the millions.
+RAM usage and startup time will be proportional to the buffer size
+(but much lower than having to load the entire corpus into RAM).
+
+### Reading Items of Different Lengths in Batches
+
+For deep learning, we want to group multiple items into batches.
+For NLP tasks, items are often lines of text of varying length.
+Infinibatch implements an algorithm that randomizes the input sequence and groups it into
+batches of approximately the same length (aka _bucketing_).
+
+Infinibatch's `BucketedReadaheadBatchIterator` performs this task.
+It implements an algorithm modeled after the [Marian toolkit](https://github.com/marian-nmt/marian)
+that preloads a large number of randomized items (typically millions; in this example: 6),
+sorts them and groups them into batches of similar length, and then yields
+them, in turn, in randomized order.
+
+Here is an example. Note that the `BucketedReadaheadBatchIterator` accepts
+the previous randomized sentence sequence iterator (`ds`) as the source of items to randomize over.
+This is an example of how one forms pipelines of iterators with Infinibatch
+(a concept familiar from Python's own `itertools`).
+Once an iterator is passed to another as its source, consider it owned by that other iterator;
+it must no longer be accessed by the calling code.
+```python
+import sys, gzip, glob
+sys.path.insert(0,'infinibatch')
+from infinibatch import datasets as ds
+from infinibatch import iterators as it
+
+ds = ds.chunked_dataset_iterator(
+    chunk_refs = glob.glob('corpus_chunks/corpus.*.txt.gz'),
+    read_chunk_fn = lambda path: iter(gzip.decompress(open(path, "rb")  \\
+                                      .read()).decode(encoding='utf-8') \\
+                                      .splitlines()),
+    buffer_size = 6, seed = 1)
+
+bs = it.BucketedReadaheadBatchIterator(
+    source_iterator = ds,   # note: this is the iterator from above
+    read_ahead = 6,
+    key = lambda line: len(line),
+    batch_size = 2,
+    seed = 1)
+
+for i in range(25):
+    print(next(bs))
+```
+This code should output something like this:
+```python
+['sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.',
+ 'The quick brown fox jumps over the lazy dog.']
+['consectetur adipiscing elit,', 'Lorem ipsum dolor sit amet,']
+['Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.',
+ 'Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.']
+```
+followed by different permutations of the same tuples.
+As you can see, the sentences are in random order and grouped in batches of 2 of approximately the same length.
+You may notice that there is no variation in how the items get grouped into batches--that
+is an artifact of this example, and generally not the case in real use when the data size is much larger
+than the batch size.
+
+In NLP, sentence length often varies considerably. As a result, using batches of a fixed number of lines,
+as in the example above, will waste GPU RAM and cores.
+This is because the number of lines is limited by the longest possible sequence; batches of shorter lines
+would leave GPU cycles on the table.
+Ideally, one would use batches that have as many lines as fit into GPU RAM,
+given the number of tokens of the longest line in the batch.
+To support variable batch sizes, Infinibatch allows you to pass a function as the `batch_size` parameter.
+That function will be given the longest item of a batch and should estimate how many items of at most this length can fit.
+
+In our example, we assume that batches can hold at most 150 tokens.
+Please change the above code as follows:
+```python
+    batch_size = lambda longest_line: 150 // len(longest_line),
+```
+The output looks like this:
+```
+['consectetur adipiscing elit,', 'Lorem ipsum dolor sit amet,']
+['Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.']
+['sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.',
+ 'The quick brown fox jumps over the lazy dog.']
+['Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.']
+```
+Note that shorter sentences got grouped, while longer ones did not because they would exceed the total of 150 characters.
+
+### Reading Batches Into Numpy Arrays
+
+Lastly, we will need to feed batches into our favorite deep-learning tool.
+We will show how to convert the batches of text lines into padded `numpy` arrays.
+
+In a typical NLP application, text items would be tokenized, and then each token
+would be represented by an index into a unit vocabulary.
+For simplicity, in this example each character is its own token,
+and each token's numeric unit index is just its ASCII code.
+These sequences are then padded to equal length with -1, and converted into a `numpy` array.
+
+Please rerun the previous example, but first insert the following code before the final `for` loop.
+This example uses an Infinibatch `MapIterator`, which applies a user-supplied function or
+lambda to each item:
+```python
+import numpy as np
+def collate(lines_batch):
+    # tokenize all lines in the batch and map to unit ids
+    ids_batch = [[ord(c) for c in line] for line in lines_batch]
+    # create a padded numpy array as wide as the longest line,
+    # where shorter sequences are padded with -1
+    width = max(len(ids) for ids in ids_batch)
+    return np.array([ids + [-1] * (width-len(ids)) for ids in ids_batch])
+
+bs = it.MapIterator(
+    source_iterator = bs,
+    transform = collate)
+```
+This will output batches like this. Note that in batches with multiple sentences,
+some entries are padded with `-1`.
+```python
+[[ 99 111 110 115 101  99 116 101 116 117 114  32  97 100 105 112 105 115
+   99 105 110 103  32 101 108 105 116  44]
+ [ 76 111 114 101 109  32 105 112 115 117 109  32 100 111 108 111 114  32
+  115 105 116  32  97 109 101 116  44  -1]]
+[[ 85 116  32 101 110 105 109  32  97 100  32 109 105 110 105 109  32 118
+  101 110 105  97 109  44  32 113 117 105 115  32 110 111 115 116 114 117
+  100  32 101 120 101 114  99 105 116  97 116 105 111 110  32 117 108 108
+   97 109  99 111  32 108  97  98 111 114 105 115  32 110 105 115 105  32
+  117 116  32  97 108 105 113 117 105 112  32 101 120  32 101  97  32  99
+  111 109 109 111 100 111  32  99 111 110 115 101 113 117  97 116  46]]
+[[115 101 100  32 100 111  32 101 105 117 115 109 111 100  32 116 101 109
+  112 111 114  32 105 110  99 105 100 105 100 117 110 116  32 117 116  32
+  108  97  98 111 114 101  32 101 116  32 100 111 108 111 114 101  32 109
+   97 103 110  97  32  97 108 105 113 117  97  46]
+ [ 84 104 101  32 113 117 105  99 107  32  98 114 111 119 110  32 102 111
+  120  32 106 117 109 112 115  32 111 118 101 114  32 116 104 101  32 108
+   97 122 121  32 100 111 103  46  -1  -1  -1  -1  -1  -1  -1  -1  -1  -1
+   -1  -1  -1  -1  -1  -1  -1  -1  -1  -1  -1  -1]]
+[[ 68 117 105 115  32  97 117 116 101  32 105 114 117 114 101  32 100 111
+  108 111 114  32 105 110  32 114 101 112 114 101 104 101 110 100 101 114
+  105 116  32 105 110  32 118 111 108 117 112 116  97 116 101  32 118 101
+  108 105 116  32 101 115 115 101  32  99 105 108 108 117 109  32 100 111
+  108 111 114 101  32 101 117  32 102 117 103 105  97 116  32 110 117 108
+  108  97  32 112  97 114 105  97 116 117 114  46]]
+```
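+
+If your deep-learning tool happens to be PyTorch, one more `MapIterator` can wrap each padded array in a tensor.
+This is only a sketch and assumes that `torch` is installed; it is not part of the example above:
+```python
+import torch
+
+# hedged sketch: convert each padded numpy batch into a LongTensor for PyTorch
+bs = it.MapIterator(
+    source_iterator = bs,
+    transform = lambda padded_batch: torch.from_numpy(padded_batch).long())
+```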
+
+## Where To Go From Here
+
+The above tutorial showed you the use of the most common iterator type, as created by the
+convenience function `chunked_dataset_iterator()`.
+
+Not all real-life scenarios are covered by this function. For example, multi-task learning
+scenarios require more complex combinations of data. To create those, you will need
+to compose the necessary data reader from the underlying building blocks.
+This is described in the documentation of the `iterators` module; a small sketch follows below.
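+
+As a tiny, hedged illustration of such composition (the data and the pairing are made up for this sketch),
+two aligned data streams can be combined item by item with a `ZipIterator`:
+```python
+from infinibatch.iterators import NativeCheckpointableIterator, ZipIterator
+
+# two illustrative, aligned data streams, e.g. source and target sentences
+sources = NativeCheckpointableIterator(["hello world", "how are you"])
+targets = NativeCheckpointableIterator(["hallo welt", "wie geht es dir"])
+# pair them up item by item; the result is again a checkpointable iterator
+pairs = ZipIterator(sources, targets)
+for source, target in pairs:
+    print(source, "->", target)
+```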
+"""
+
+
+
+

Sub-modules

+
+
infinibatch.closablequeue
+
+
+
+
infinibatch.datasets
+
+
+
+
infinibatch.iterators
+
+

Overview …

+
+
infinibatch.torch
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + + + \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/iterators.html b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/iterators.html new file mode 100644 index 0000000000000000000000000000000000000000..ace103bf246d5956b51285592bdc2cdeae494053 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/iterators.html @@ -0,0 +1,2696 @@ + + + + + + +infinibatch.iterators API documentation + + + + + + + + + +
+
+
+

Module infinibatch.iterators

+
+
+

Overview

+

This part of the documentation covers the advanced usage of Infinibatch by assembling custom data loading pipelines. +Before you continue, please go through the tutorial on the top-level of the documentation of the infinibatch module.

+

Two of the main features of Infinibatch are lazy evaluation through the use of iterators +and built-in support for checkpointing. +In this section, we give an introduction to these features and the basic usage of the Infinibatch iterator library.

+

Iterators

+

As a Python programmer, you are probably familiar with the concept of iterators. +According to the Python documentation, +an iterator is an object representing a stream of data, +and repeated calls to the iterator's __next__() method (or passing it to the built-in function next()) +return successive items in the stream. +It is important not to confuse an iterator +with an iterable. +For more information on this subject, please follow the links above.

+

The Python standard library contains a module of iterators called itertools +that bears some resemblance to Infinibatch. +Infinibatch differs from itertools in two ways:

+
    +
  1. Infinibatch provides iterators specifically for the purpose of creating randomized batches of data for machine learning.
  2. +
  3. All iterators in Infinibatch support checkpointing (see the following section).
  4. +
+

Infinibatch iterators are not directly compatible with itertools due to the checkpointing requirement.

+

Infinibatch enables you to build complex data loaders by combining iterators from this module into a pipeline. +To give you a high-level idea of how this works, we provide a very simple example. +Note that this example is completely artificial and does not solve any useful task. +Its only purpose is to demonstrate the behavior of a pipeline of iterators. +We provide a more realistic example in a later section.

+

First, we create a small test data set.

+
>>> dataset = list(range(6))  # 0, 1, 2, 3, 4, 5
+
+

We can turn this data set into an Infinibatch iterator by wrapping it in a NativeCheckpointableIterator.

+
>>> it = NativeCheckpointableIterator(dataset)  # 0, 1, 2, 3, 4, 5
+
+

We can then transform the data items using a MapIterator, +which applies a given function to each individual data item. +For example, we can multiply each data item by 2.

+
>>> it = MapIterator(it, lambda n: 2 * n)  # 0, 2, 4, 6, 8, 10
+
+

We can restructure the data set by batching together pairs of data items into lists using a FixedBatchIterator.

+
>>> it = FixedBatchIterator(it, batch_size=2)  # [0, 2], [4, 6], [8, 10]
+
+

Using another MapIterator, we can reduce each of these lists to its second element.

+
>>> it = MapIterator(it, lambda l: l[1])  # 2, 6, 10
+
+

Finally, we can use the resulting iterator it just like any standard Python iterator.

+
>>> for item in it:
+...     print(item)
+2
+6
+10
+
+
+

By using iterators, Infinibatch operates in a lazy fashion: +It generally doesn't apply operations to an entire data set at once, +but rather operates on individual data items on-the-fly as they are consumed. +When used correctly, this allows Infinibatch to have a low start-up time and low memory overhead. +For more detail on this, please consult the section on performance considerations below.

+

Checkpointing

+

The main feature that sets Infinibatch iterators apart from standard Python iterators is that they support checkpointing. +A checkpoint encapsulates the internal state of an entire pipeline of iterators at a specific point while iterating through a data set. +Once you retrieve a checkpoint, you can later use it to reset the pipeline of iterators to the exact state it was in +when the checkpoint was created. +Checkpoints can easily be serialized and stored to disk using Python's pickle module. +Infinibatch's checkpointing feature is particularly useful when you're training large deep neural network models over days or weeks, +and you want to make sure that, in case your training is interrupted for any reason, you can pick up your training exactly where you left off.

+

The checkpointing interface consists of two functions getstate and setstate that are defined in CheckpointableIterator, +the common base class of all iterators in this module. +As the names suggest, getstate returns a checkpoint object that represents the state of a pipeline at the time the function is called, +and setstate receives a checkpoint object to reset the state of a pipeline. +setstate also accepts None, which resets a pipeline to the beginning of the iteration, +i.e. the state of the pipeline immediately after its construction.

+

It is important to realize that a checkpoint represents the state of a complete pipeline of iterators. +If you have a pipeline consisting of a sequence of iterators, you only have to call getstate on the last iterator in the sequence +to capture the state of the entire pipeline. +Internally, this is achieved by recursive calls that traverse the entire data loading pipeline to collect the state of every iterator in it. +Similarly, when you want to reset a pipeline to a previous state, you only have to call setstate on the last iterator in the pipeline.

+

To demonstrate this, we recreate the pipeline from the previous section.

+
>>> dataset = list(range(6))  # 0, 1, 2, 3, 4, 5
+>>> it = NativeCheckpointableIterator(dataset)  # 0, 1, 2, 3, 4, 5
+>>> it = MapIterator(it, lambda n: 2 * n)  # 0, 2, 4, 6, 8, 10
+>>> it = FixedBatchIterator(it, batch_size=2)  # [0, 2], [4, 6], [8, 10]
+>>> it = MapIterator(it, lambda l: l[1])  # 2, 6, 10
+
+

Since it behaves just like a standard Python iterator, we can call next to retrieve its first element.

+
>>> next(it)
+2
+
+

We can now call getstate on it (which is the last MapIterator in the pipeline) +to get a checkpoint of the internal state of the entire data loading pipeline.

+
>>> checkpoint = it.getstate()
+
+

Note that the checkpoint represents the internal state of the pipeline after the data item 2 has been retrieved. +Using the checkpoint, we can always return to this exact point in the data set. +To show this, let's exhaust the iterator by casting it to a list.

+
>>> list(it)
+[6, 10]
+
+

Since the iterator is now exhausted, calling next raises a StopIteration exception.

+
>>> next(it)
+Traceback (most recent call last):
+    ...
+StopIteration
+
+
+

We can now reset the pipeline to the checkpoint using setstate.

+
>>> it.setstate(checkpoint)
+
+

This recovers the state of the pipeline after the data item 2 has been retrieved. +Thereby, we expect the next element to be 6.

+
>>> next(it)
+6
+
+

Types of Iterators

+

This section provides a brief overview of the different types of iterators in Infinibatch.

+

Classes and Factory Functions

+

Most iterators in this module are implemented as classes that inherit from the abstract base class CheckpointableIterator. +However, some iterators (such as the BlockwiseShuffleIterator()) are simple combinations of other iterators. +These iterators are implemented as factory functions that construct a pipeline of iterators +and return the last iterator in the pipeline. +For consistency with class-based iterators, +we name these factory functions using CamelCase instead of the more pythonic use_of_underscores.

+
+

TODO

+

We currently also have one factory function that actually looks like one: create_source_iterator(). +Provide a comment on this describing why that is.

+
+

Source Iterators

+

There are three iterators that are intended to go at the beginning of a data loading pipeline:

+
    +
  • InfinitePermutationSourceIterator: +This iterator accepts a list, shuffles it, and yields its elements. +It repeats this infinitely, shuffling the list after each pass. +Thereby, this iterator is infinite and cannot be exhausted. +This iterator is meant to be used as the first iterator in a training scenario +and supports splitting the data for multi-GPU training.
  • +
  • ChunkedSourceIterator(): +This iterator accepts a list and yields its elements. +It is meant to be used as the first iterator in an inference or validation scenario +and supports splitting the data for multi-GPU inference.
  • +
  • NativeCheckpointableIterator: +This iterator wraps a Python iterable and makes it checkpointable. +It is mainly intended for demonstration and debugging purposes.
  • +
+

Shuffling

+ +

Batching, SelectMany, and Windowing

+ +

Mapping

+ +

Other Iterators

+ +

Complete Example

+
+

TODO

+

Give a more realistic example following, in broad strokes, the ChunkedDataset including:

+
    +
  • use gzip chunks
  • +
  • training pipeline example
  • +
  • inference pipeline example
  • +
  • pipeline that can do both
  • +
  • etc.
  • +
+
+

Performance Considerations

+
+

TODO

+

Describe what parameters influence performance measures such as memory usage and start-up time.

+
+
+ +Expand source code + +
"""
+## Overview
+
+This part of the documentation covers the __advanced usage__ of Infinibatch by assembling __custom data loading pipelines__.
+Before you continue, please go through the tutorial on the top-level of the documentation of the `infinibatch` module.
+
+Two of the main features of Infinibatch are __lazy evaluation__ through the use of __iterators__
+and built-in support for __checkpointing__.
+In this section, we give an introduction to these features and the basic usage of the Infinibatch iterator library.
+
+
+### Iterators
+
+As a Python programmer, you are probably familiar with the concept of iterators.
+According to the [Python documentation](https://docs.python.org/3.5/glossary.html#term-iterator),
+an iterator is an object representing a stream of data,
+and repeated calls to the iterator's `__next__()` method (or passing it to the built-in function `next()`)
+return successive items in the stream.
+It is important not to confuse an [iterator](https://docs.python.org/3.5/glossary.html#term-iterator)
+with an [iterable](https://docs.python.org/3.5/glossary.html#term-iterable).
+For more information on this subject, please follow the links above.
+
+The Python standard library contains a module of iterators called `itertools`
+that bears some resemblance to Infinibatch.
+Infinibatch differs from `itertools` in two ways:
+
+1. Infinibatch provides iterators specifically for the purpose of creating __randomized batches of data for machine learning__.
+2. All iterators in Infinibatch support __checkpointing__ (see the following section).
+
+Infinibatch iterators are not directly compatible with itertools due to the checkpointing requirement.
+
+Infinibatch enables you to build complex data loaders by combining iterators from this module into a pipeline.
+To give you a high-level idea of how this works, we provide a very simple example.
+Note that this example is completely artificial and does not solve any useful task.
+Its only purpose is to demonstrate the behavior of a pipeline of iterators.
+We provide a more realistic example in a later section.
+
+First, we create a small test data set.
+>>> dataset = list(range(6))  # 0, 1, 2, 3, 4, 5
+
+We can turn this data set into an Infinibatch iterator by wrapping it in a `NativeCheckpointableIterator`.
+>>> it = NativeCheckpointableIterator(dataset)  # 0, 1, 2, 3, 4, 5
+
+We can then transform the data items using a `MapIterator`,
+which applies a given function to each individual data item.
+For example, we can multiply each data item by 2.
+>>> it = MapIterator(it, lambda n: 2 * n)  # 0, 2, 4, 6, 8, 10
+
+We can restructure the data set by batching together pairs of data items into lists using a `FixedBatchIterator`.
+>>> it = FixedBatchIterator(it, batch_size=2)  # [0, 2], [4, 6], [8, 10]
+
+Using another `MapIterator`, we can reduce each of these lists to its second element.
+>>> it = MapIterator(it, lambda l: l[1])  # 2, 6, 10
+
+Finally, we can use the resulting iterator `it` just like any standard Python iterator.
+```py
+>>> for item in it:
+...     print(item)
+2
+6
+10
+
+```
+
+By using iterators, Infinibatch operates in a __lazy__ fashion:
+It generally doesn't apply operations to an entire data set at once,
+but rather operates on individual data items on-the-fly as they are consumed.
+When used correctly, this allows Infinibatch to have a low start-up time and low memory overhead.
+For more detail on this, please consult the section on performance considerations below.
+
+
+### Checkpointing
+
+The main feature that sets Infinibatch iterators apart from standard Python iterators is that they support __checkpointing__.
+A checkpoint encapsulates the internal state of an entire pipeline of iterators at a specific point while iterating through a data set.
+Once you retrieve a checkpoint, you can later use it to reset the pipeline of iterators to the exact state it was in
+when the checkpoint was created.
+Checkpoints can easily be serialized and stored to disk using [Python's `pickle` module](https://docs.python.org/3.5/library/pickle.html).
+Infinibatch's checkpointing feature is particularly useful when you're training large deep neural network models over days or weeks,
+and you want to make sure that, in case your training is interrupted for any reason, __you can pick up your training exactly where you left off__.
+
+The checkpointing interface consists of two functions `getstate` and `setstate` that are defined in `CheckpointableIterator`,
+the common base class of all iterators in this module.
+As the names suggest, `getstate` returns a checkpoint object that represents the state of a pipeline at the time the function is called,
+and `setstate` receives a checkpoint object to reset the state of a pipeline.
+`setstate` also accepts `None`, which resets a pipeline to the __beginning__ of the iteration,
+i.e. the state of the pipeline immediately after its construction.
+
+It is important to realize that __a checkpoint represents the state of a complete pipeline of iterators__.
+If you have a pipeline consisting of a sequence of iterators, you only have to call `getstate` on the __last__ iterator in the sequence
+to capture the state of the entire pipeline.
+Internally, this is achieved by recursive calls that traverse the entire data loading pipeline to collect the state of every iterator in it.
+Similarly, when you want to reset a pipeline to a previous state, you only have to call `setstate` on the __last__ iterator in the pipeline.
+
+
+To demonstrate this, we recreate the pipeline from the previous section.
+>>> dataset = list(range(6))  # 0, 1, 2, 3, 4, 5
+>>> it = NativeCheckpointableIterator(dataset)  # 0, 1, 2, 3, 4, 5
+>>> it = MapIterator(it, lambda n: 2 * n)  # 0, 2, 4, 6, 8, 10
+>>> it = FixedBatchIterator(it, batch_size=2)  # [0, 2], [4, 6], [8, 10]
+>>> it = MapIterator(it, lambda l: l[1])  # 2, 6, 10
+
+Since `it` behaves just like a standard Python iterator, we can call `next` to retrieve its first element.
+>>> next(it)
+2
+
+We can now call `getstate` on `it` (which is the last `MapIterator` in the pipeline)
+to get a checkpoint of the internal state of the entire data loading pipeline.
+>>> checkpoint = it.getstate()
+
+Note that the checkpoint represents the internal state of the pipeline after the data item `2` has been retrieved.
+Using the checkpoint, we can always return to this __exact__ point in the data set.
+To show this, let's exhaust the iterator by casting it to a list.
+>>> list(it)
+[6, 10]
+
+Since the iterator is now exhausted, calling `next` raises a `StopIteration` exception.
+```
+>>> next(it)
+Traceback (most recent call last):
+    ...
+StopIteration
+
+```
+
+We can now reset the pipeline to the checkpoint using `setstate`.
+>>> it.setstate(checkpoint)
+
+This recovers the state of the pipeline after the data item `2` has been retrieved.
+Thereby, we expect the next element to be `6`.
+>>> next(it)
+6
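+
+Since a checkpoint is an ordinary picklable object, it can, for example, be written to disk next to a model checkpoint
+and loaded again after a restart. A minimal sketch (the file name is made up):
+```python
+import pickle
+
+# save the checkpoint to disk ...
+with open("data_checkpoint.pkl", "wb") as f:
+    pickle.dump(checkpoint, f)
+# ... and later restore the pipeline from it
+with open("data_checkpoint.pkl", "rb") as f:
+    it.setstate(pickle.load(f))
+```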
+
+
+## Types of Iterators
+
+This section provides a brief overview of the different types of iterators in Infinibatch.
+
+
+### Classes and Factory Functions
+
+Most iterators in this module are implemented as classes that inherit from the abstract base class `CheckpointableIterator`.
+However, some iterators (such as the `BlockwiseShuffleIterator`) are simple combinations of other iterators.
+These iterators are implemented as __factory functions__ that construct a pipeline of iterators
+and return the last iterator in the pipeline.
+For consistency with class-based iterators,
+we name these factory functions using CamelCase instead of the more pythonic use_of_underscores.
+
+.. todo::
+    We currently also have one factory function that actually looks like one: `create_source_iterator`.
+    Provide a comment on this describing why that is.
+
+
+### Source Iterators
+
+There are three iterators that are intended to go at the __beginning__ of a data loading pipeline:
+
+- `InfinitePermutationSourceIterator`:
+This iterator accepts a list, shuffles it, and yields its elements.
+It repeats this infinitely, shuffling the list after each pass.
+Thereby, __this iterator is infinite and cannot be exhausted__.
+This iterator is meant to be used as the first iterator in a training scenario
+and supports splitting the data for multi-GPU training.
+- `ChunkedSourceIterator`:
+This iterator accepts a list and yields its elements.
+It is meant to be used as the first iterator in an inference or validation scenario
+and supports splitting the data for multi-GPU inference.
+- `NativeCheckpointableIterator`:
+This iterator wraps a Python iterable and makes it checkpointable.
+It is mainly intended for demonstration and debugging purposes.
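+
+As a brief, hedged sketch (the corpus and the two-worker setup are made up), a training-time source such as the
+`InfinitePermutationSourceIterator` described above could be instantiated like this:
+```python
+corpus = ["sentence 1", "sentence 2", "sentence 3", "sentence 4"]
+# worker 0 of 2 sees an infinite, reshuffled stream of its share of the items
+source = InfinitePermutationSourceIterator(corpus, seed=42, num_instances=2, instance_rank=0)
+```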
+
+
+### Shuffling
+
+.. todo:: Describe `BufferedShuffleIterator` and `BlockwiseShuffleIterator`.
+
+
+### Batching, SelectMany, and Windowing
+
+.. todo:: Describe `FixedBatchIterator`, `SelectManyIterator`, and `WindowedIterator`.
+
+
+### Mapping
+
+.. todo:: Describe `MapIterator`, `ParallelMapIterator`, `RecurrentIterator`, and `SamplingRandomMapIterator`.
+
+
+### Other Iterators
+
+.. todo:: Describe `ZipIterator`, `PrefetchIterator`, and `BucketedReadaheadBatchIterator`.
+
+
+## Complete Example
+
+.. todo::
+    Give a more realistic example following, in broad strokes, the ChunkedDataset including:
+
+    - use gzip chunks
+    - training pipeline example
+    - inference pipeline example
+    - pipeline that can do both
+    - etc.
+
+## Performance Considerations
+
+.. todo::
+    Describe what parameters influence performance measures such as memory usage and start-up time.
+"""
+
+from abc import abstractmethod
+import collections
+import copy
+import gzip
+from itertools import cycle, islice
+import math
+from multiprocessing import Pool
+import os
+from queue import Full, Queue
+from random import Random
+from threading import Thread
+from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
+
+
+from infinibatch.closablequeue import ClosableQueue, ClosedException
+
+
+# TODO for next release:
+#  - benchmark the accuracy when using BlockwiseShuffleIterator vs. the BufferedShuffleIterator
+#  - change all convenience functions back to true classes, using a wrapper class
+
+# TODO later:
+# - make iterator pipeline work for streaming data
+
+def _advance_iterator(iterator: Iterator, n: int):
+    """ Little helper to advance an iterator by n items """
+    for _ in range(n):
+        next(iterator)
+    return n
+
+
+class CheckpointableIterator(collections.abc.Iterator):
+    """
+    Abstract base class that defines the interface for checkpointing.
+    
+    The interface (getstate, setstate) is inspired by Python's random package.
+    """
+    def __iter__(self):
+        return self
+
+    @abstractmethod
+    def getstate(self) -> Dict:
+        """
+        Get checkpoint of current state of iterator
+        
+        In a pipeline of iterators, this function __recursively__ calls itself on the preceding iterator
+        and includes the gathered information in the returned checkpoint.
+        Thereby, to obtain a checkpoint of the state of an entire pipeline of iterators
+        you only have to call this function on the __last__ iterator in the pipeline.
+        A checkpoint is represented as a `dict`,
+        but the caller should treat a checkpoint as an opaque object
+        and not make any assumptions about the existence or meaning of the `dict` entries.
+        """
+        pass
+
+    @abstractmethod
+    def setstate(self, checkpoint: Optional[Dict]):
+        """
+        Set state of iterator to given checkpoint
+
+        In a pipeline of iterators, this function __recursively__ calls itself on the preceding iterator.
+        Thereby, to set the state of an entire pipeline of iterators to a given checkpoint
+        you only have to call this function on the __last__ iterator in the pipeline.
+
+        Args:
+            checkpoint: Checkpoint that should be used to reset the state of the iterator (or pipeline).
+                        If this is __None__, the state of the iterator (or pipeline) is reset to the initial
+                        state immediately after construction.
+        """
+        pass
+
+    def __getstate__(self) -> Dict:  # implementation of pickle Protocol
+        return self.getstate()

+
+    def __setstate__(self, checkpoint: Optional[Dict]):
+        self.setstate(checkpoint)
+
+    @abstractmethod
+    def __next__(self):
+        pass
+
+
+class NativeCheckpointableIterator(CheckpointableIterator):
+    """
+    Simple wrapper class that turns a Python Iterable into a CheckpointableIterator
+    
+    When calling setstate on this class, it simply replays the iterator all the way to the checkpoint one element at a time,
+    which makes it generally inefficient.
+
+    Warning: This class cannot be used with Iterators (as opposed to Iterables), which have an `__iter__` function that simply returns self, but does not reset.
+    """
+    def __init__(self, iterable: Iterable):
+        # check whether iterable is iterable or iterator:
+        # if the variable iterable contains an iterator, the function __iter__ returns self
+        # if the variable iterable is an actual iterator, it should not return self
+        if iter(iterable) is iterable:  
+            raise ValueError('It looks like you are passing an iterator instead of an iterable. This is not supported and can cause undefined behavior when used with checkpointing.')
+        self._input_iterable = iterable
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'num_items_yielded': self._num_items_yielded}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._iterator = iter(self._input_iterable)
+        self._num_items_yielded = _advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
+
+    def __next__(self):
+        item = next(self._iterator)  # call this before increasing _num_items_yielded to correctly handle the case when a StopIteration exception is thrown
+        self._num_items_yielded += 1
+        return item
+
+
+def create_source_iterator(source_items: List, train: bool=True, seed: Optional[int]=None, shuffle: bool=True, num_instances: int=1, instance_rank: int=0):
+    if not train and shuffle:
+        raise ValueError('shuffling is not supported when train=False')
+    if train:
+        return InfinitePermutationSourceIterator(source_items, seed=seed, shuffle=shuffle, num_instances=num_instances, instance_rank=instance_rank)
+    else:
+        return ChunkedSourceIterator(source_items, num_instances=num_instances, instance_rank=instance_rank)
+
+
+def ChunkedSourceIterator(source_items: List, num_instances: int=1, instance_rank: int=0):
+    """
+    Cuts source list into chunks, one per instance, and serves out items in chunk corresponding to instance_rank
+
+    This is a source iterator:
+    It is meant to be used at the beginning of a data loading pipeline.
+    As such, it takes a list as its source and not a CheckpointableIterator.
+
+    Args:
+        source_items: input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it!
+        num_instances: number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+        instance_rank: rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+    """
+    # heuristic: assuming blocks are all of the same size, math.ceil should give us the shortest makespan
+    chunk_size = math.ceil(len(source_items) / num_instances)
+    # this does not cause any out-of-bounds issues:
+    # a slice with a start-index beyond the end of the list is empty,
+    # and an end-index of a slice is capped at the end of the list
+    chunk = source_items[instance_rank * chunk_size : (instance_rank + 1) * chunk_size]
+    return NativeCheckpointableIterator(chunk)
+
+
+class InfinitePermutationSourceIterator(CheckpointableIterator):
+    """
+    Infinitely generates permutations of the items in the given list.
+
+    This is a source iterator:
+    It is meant to be used at the beginning of a data loading pipeline.
+    As such, it takes a list as its source and not a CheckpointableIterator.
+    The given list is loaded completely into RAM.
+
+    For example, this is used for randomizing the pathnames of data blocks read by ChunkedReadlinesIterator.
+    """
+    def __init__(self, source_items: List, seed: Optional[int]=None, shuffle: bool=True, num_instances: int=1, instance_rank: int=0):
+        """
+        Args:
+            source_items: input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it!
+            seed: random seed used for shuffling (or None)
+            shuffle: set False to bypass the shuffling. Then this is just a checkpointed version of itertools.cycle(). (Default: True)
+            num_instances: number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+            instance_rank: rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+        """
+        self._source_items = source_items
+        if not self._source_items:
+            raise ValueError("InfinitePermutationIterator: source must not be empty")
+        self._shuffle = shuffle
+        self._seed = seed
+        self._num_instances = num_instances
+        self._instance_rank = instance_rank
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'random_state':      self._random_state,  # state of random generator before generating the current shuffling of the sequence
+                'num_items_yielded': self._num_items_yielded}    # how many items have already been iterated over in the current shuffling
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        # set iteration state. Do this outside the generator below in case getstate() is called before ever iterating
+        self._random_state      = checkpoint['random_state']      if checkpoint else None
+        self._num_items_yielded = checkpoint['num_items_yielded'] if checkpoint else 0
+        # We define the iteration itself as a generator for ease of implementation.
+        # We could as well just have used an explicit state machine represented by class members.
+        def _generate() -> Iterator:
+            # create and reset random generator
+            random = Random(self._seed)
+            if self._random_state is not None:  # restore the random generator's state
+                random.setstate(self._random_state)
+            skip_to_checkpoint = self._num_items_yielded  # items to skip in order to advance to checkpoint
+            # main outer loop for infinite passes over items (reshuffle before each pass)
+            while True:
+                # (re-)shuffle all items
+                self._random_state = random.getstate()  # remember random state before shuffling
+                self._num_items_yielded   = 0
+                shuffled_items = self._source_items[:]  # note: if underlying iterator is checkpointable, use setstate(checkpoint['nested_state']) on it
+                if self._shuffle:
+                    random.shuffle(shuffled_items)
+                shuffled_iterator = iter(shuffled_items)
+                # skip initial items when restarting from checkpoint
+                if skip_to_checkpoint:  # @TODO: find a way to abstract this more, so that we can plug it into the 'for' statement directly
+                    self._num_items_yielded += _advance_iterator(shuffled_iterator, skip_to_checkpoint)
+                    skip_to_checkpoint = 0  # done skipping
+                # main inner loop over items
+                for item in shuffled_iterator:
+                    self._num_items_yielded += 1  # record how many items we have iterated over in this pass over the items
+                    if (self._num_items_yielded-1) % self._num_instances == self._instance_rank:  # built-in islice facility
+                        yield item
+        self._iterator = _generate()
+
+    def __next__(self):
+        return next(self._iterator)
+
+
+class SelectManyIterator(CheckpointableIterator):
+    """
+    Projects each element of a source sequence to a sequence and flattens the resulting sequences into one sequence.
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, collection_selector: Optional[Callable[[Any], Iterator]]=None):
+        """
+        Args:
+            source_iterator: iterator over the items to pass to collection_selector()
+            collection_selector: user callback that maps an item into an Iterable, whose items will be yielded.
+                                 The returned Iterator is used only once. Hence, it is also allowed to
+                                 return self-iterables, such as iterators and generator expressions.
+                                 If None is given, no callback is applied.
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator          # type: CheckpointableIterator
+        self._collection_selector = collection_selector  # type: Callable[[Any], Iterator]
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state':            self._source_state,
+                'flattened_items_yielded': self._flattened_items_yielded}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state            = checkpoint['source_state']            if checkpoint else None
+        self._flattened_items_yielded = checkpoint['flattened_items_yielded'] if checkpoint else 0
+        self._source_iterator.setstate(self._source_state)
+        def _generate():
+            skip_to_checkpoint = self._flattened_items_yielded
+            # main loop over source items
+            for source_item in self._source_iterator:
+                if self._collection_selector is not None:
+                    data = iter(self._collection_selector(source_item))
+                else:
+                    data = iter(source_item)
+                self._flattened_items_yielded = 0
+                if skip_to_checkpoint:
+                    #print("Skipping to index", skip_to_checkpoint, file=sys.stderr)
+                    self._flattened_items_yielded += _advance_iterator(data, skip_to_checkpoint)
+                    skip_to_checkpoint = 0
+                # main loop over lines
+                for item in data:
+                    self._flattened_items_yielded += 1
+                    yield item
+                self._source_state = self._source_iterator.getstate()
+        self._iterator = _generate()
+
+    def __next__(self):
+        return next(self._iterator)
+
+
+class BufferedShuffleIterator(CheckpointableIterator):
+    """
+    Shuffles given iterable using a limited buffer.
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, buffer_size: int, seed: int = 0):
+        """
+        Args:
+            source_iterator: checkpointable iterator or restartable iterable over input items to shuffle
+            buffer_size: size of the buffer in number of items used for shuffling
+            seed: random seed used for shuffling (or None)
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator
+        self._buffer = [None for _ in range(buffer_size)]  # maybe do this lazily?   --Yes, since user may set state immediately, then this is not needed here
+        self._random = Random(seed)
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_iterator.getstate(),
+                'buffer':       copy.deepcopy(self._buffer),
+                'random_state': self._random.getstate()}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        if checkpoint:
+            self._source_iterator.setstate(checkpoint['source_state'])
+            self._buffer = checkpoint['buffer']
+            self._random.setstate(checkpoint['random_state'])
+            # @TODO: Can we add a comment how the flush part is handled?
+        else:
+            self._source_iterator.setstate(None)
+        self._iterator = self._generate()
+
+    def _generate(self) -> Iterator:
+        # shuffle data with a buffer:
+        # this is similar to what the Fisher-Yates shuffle does,
+        # but modified to run with a constant-size buffer
+        # see https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
+        # this was inspired by an algorithm implemented in Kaldi
+        # see https://kaldi-asr.org/doc/nnet-shuffle-egs_8cc.html
+        for item in self._source_iterator:
+            index = self._random.randrange(0, len(self._buffer))
+            result = None
+            if self._buffer[index] is not None:
+                result = self._buffer[index]
+            self._buffer[index] = item
+            # only yield value once buffer is updated to allow for correct checkpointing!
+            if result is not None:
+                yield result
+
+        # flush buffer
+        while self._buffer:
+            item = self._buffer.pop()
+            if item is not None:
+                yield item
+
+    def __next__(self):
+        return next(self._iterator)
+
+
+class MapIterator(CheckpointableIterator):
+    """
+    Applies given transform to each data item
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, transform: Callable[[str],Any]):
+        """
+        Args:
+            source_iterator: checkpointable iterator
+            transform: function to be applied to each data item
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator
+        self._transform = transform
+
+    def getstate(self) -> Dict:
+        return self._source_iterator.getstate()
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_iterator.setstate(checkpoint)
+
+    def __next__(self):
+        return self._transform(next(self._source_iterator))
+
+
+def ParallelMapIterator(source_iterator: CheckpointableIterator, transform: Callable[[str],Any], num_processes: int, num_items_per_process: int):
+    """
+    Applies given transform to each data item
+
+    Behaves the same as MapIterator, but applies transform in parallel using multiple processes in a parallel map operation.
+
+    Warning:
+    The transform function has to be pickleable because it is sent across process boundaries.
+    To achieve this, transform should be a top-level function.
+
+    Args:
+        source_iterator: checkpointable iterator
+        transform: function to be applied to each data item, has to be pickleable, see above
+        num_processes: number of processes to use for parallel map
+        num_items_per_process: number of data items each process operates on
+    """
+    # divide stream of data items into batches
+    batched_samples = FixedBatchIterator(source_iterator, num_processes * num_items_per_process)
+    # create process pool and capture it in closure that performs parallel map
+    p = Pool(num_processes)
+    def parallel_map_transform(buffer):
+        return p.map(transform, buffer)
+    # apply transform in parallel to data items in a batch
+    batched_transformed_samples = MapIterator(batched_samples, parallel_map_transform)
+    # unpack batches to go back to stream of (now transformed) data items
+    transformed_samples = SelectManyIterator(batched_transformed_samples)
+    return transformed_samples
+
+
+class ZipIterator(CheckpointableIterator):
+    """
+    Zips items from all given iterators, like the Python standard function zip().
+
+    Like Python's built-in zip(), the iteration stops when the shortest input iterable is exhausted.
+    """
+    def __init__(self, *source_iterators: CheckpointableIterator):
+        """
+        Args:
+            source_iterators: list of iterators to zip, item by item
+        """
+        for source_iterator in source_iterators:
+            if not isinstance(source_iterator, CheckpointableIterator):
+                raise ValueError('all iterators in source_iterators have to be CheckpointableIterator')
+        self._source_iterators = source_iterators    # type: List[CheckpointableIterator]
+
+    def getstate(self) -> Dict:
+        return {'input_states': tuple(iterator.getstate() for iterator in self._source_iterators)}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        if checkpoint is None:
+            for iterator in self._source_iterators:
+                iterator.setstate(None)
+        else:
+            for iterator, state in zip(self._source_iterators, checkpoint['input_states']):
+                iterator.setstate(state)
+
+    def __next__(self):
+        res = []  # (note: can't use a generator expression, as it gets confused when a next() call raises StopIteration)
+        for iterator in self._source_iterators:
+            res.append(next(iterator))
+        return tuple(res)
+
+
+# @TODO: The yield makes a (shallow) copy of the window, which has complexity O(width * length). In some cases,
+#        we don't actually need to consume all items in the window. Hence, to make this faster, we should use
+#        double-buffering and return a slice view (which we'd have to write).
+class WindowedIterator(CheckpointableIterator):
+    """
+    Yields 'width' consecutive items in a sliding window.
+
+    E.g. [1, 2, 3, 4, 5, 6] with width = 3 will yield
+    [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, width: int):
+        """
+        Args:
+            source_iterator: checkpointable input iterators
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type: CheckpointableIterator
+        self._width = width                      # type: int
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_state,  # state for first item in FIFO
+                'item_index':  self._item_index}   # index of next item to serve
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state = checkpoint['source_state'] if checkpoint else None
+        self._item_index   = checkpoint['item_index']   if checkpoint else 0
+        self._source_iterator.setstate(self._source_state)
+        self._iterator = self._generate()
+
+    def _fifo_slice(self, i):  # returns a window into the FIFO beginning at i
+        # @TODO: for efficiency, make this a slice view
+        return tuple(self._fifo[i:i + self._width])
+
+    def _generate(self) -> Iterator:
+        self._source_state = self._source_iterator.getstate()
+        self._fifo = list(islice(self._source_iterator, self._width))
+        # we do this in overlapping blocks of length 2*width, for easier checkpointing and potential efficiency
+        while len(self._fifo) == self._width:
+            # we got 'width' items; append another 'width' (or less if at end)
+            next_input_state = self._source_iterator.getstate()
+            self._fifo.extend(islice(self._source_iterator, self._width))
+            # now serve all positions in first half (last = width - 1). If at end, then limit accordingly.
+            last = min(self._width - 1, len(self._fifo) - self._width)
+            while self._item_index <= last:
+                window = self._fifo_slice(self._item_index)
+                self._item_index += 1
+                yield window
+            # drop all we just served; if < width left, we have hit the end
+            self._fifo = self._fifo[last + 1:]    # Note: This must be a new list, since the old might still be in a slice view.
+            self._source_state = next_input_state  # this reflects now the first element in the FIFO 
+            self._item_index = 0
+
+    def __next__(self):
+        return next(self._iterator)
+
+
+# @TODO: research on whether this operation has a well-known name
+class FixedBatchIterator(CheckpointableIterator):
+    """
+    Batches N consecutive items into a single item that is a list of these items.
+
+    E.g. [1, 2, 3, 4, 5, 6, 7, 8] with batch_size = 3 will yield
+    [[1, 2, 3], [4, 5, 6], [7, 8]]
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, batch_size: int):
+        """
+        Args:
+            source_iterator: checkpointable input iterators
+            batch_size: number of items per batch
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type: CheckpointableIterator
+        self._batch_size = batch_size            # type: int
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_iterator.getstate()}  # state for first item in next batch
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state = checkpoint['source_state'] if checkpoint else None
+        self._source_iterator.setstate(self._source_state)
+        self._iterator = self._generate()
+
+    def _generate(self) -> Iterator:
+        while True:
+            batch = list(islice(self._source_iterator, self._batch_size))
+            if not batch:
+                break
+            yield batch
+
+    def __next__(self):
+        return next(self._iterator)
+
+
+class RandomIterator(CheckpointableIterator):
+    """
+    Iterator to generate uniformly distributed random numbers in the interval [0,1).
+    Very similar to Random.random(), except that random numbers are
+    obtained via next().
+    """
+    def __init__(self, seed: Optional[int]=None):
+        """
+        Args:
+            seed: Random seed.
+        """
+        self._random = Random()  # type: Random
+        if seed is not None:
+            self._random.seed(seed)
+
+    def getstate(self) -> Dict:
+        return {'random_state': self._random.getstate()}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._random.setstate(checkpoint['random_state'] if checkpoint else None)
+
+    def __next__(self):
+        return self._random.random()
+
+
+class RecurrentIterator(CheckpointableIterator):
+    """
+    Iterates statefully over a step function. The step function accepts a state and a new item,
+    and returns a new state and an output item, which is yielded.
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, step_function: Callable[[Any,Any], Tuple[Any,Any]], initial_state: Any = None):
+        """
+        Args:
+            source_iterator: checkpointable iterator to recur over
+            step_function: user-supplied function with signature step_function(state, item) -> (new_state, output)
+            initial_state: initial state to be passed to the step_function upon first invocation
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type: CheckpointableIterator
+        self._step_function = step_function      # type: Callable[[Any,Any], Tuple[Any,Any]]
+        self._initial_state = initial_state      # type: Any
+        self.setstate(None)
+    
+    def getstate(self):
+        return {'recurrent_state': self._recurrent_state,
+                'source_state':    self._source_iterator.getstate()}
+    
+    def setstate(self, checkpoint):
+        self._recurrent_state = checkpoint['recurrent_state'] if checkpoint else self._initial_state
+        self._source_iterator.setstate(checkpoint['source_state'] if checkpoint else None)
+        def _generate():
+            for item in self._source_iterator:
+                self._recurrent_state, output = self._step_function(self._recurrent_state, item)
+                yield output
+        self._iterator = _generate()
+
+    def __next__(self):
+        return next(self._iterator)
+
+
+def SamplingRandomMapIterator(source_iterator: CheckpointableIterator, transform: Callable[[Random,Any],Any], seed: Optional[int]=None):
+    """
+    An iterator that calls a transform function on each item, while also passing a checkpointed
+    random generator.
+
+    Args:
+        source_iterator: checkpointable iterator to recur over
+        step_function: user-supplied function with signature step_function(random, item) -> result_item
+        seed: random seed
+    """
+    _random = Random()
+    if seed is not None:
+        _random.seed(seed)
+    def _step_function(state, item):
+        _random.setstate(state)
+        output = transform(_random, item)
+        return _random.getstate(), output
+    return RecurrentIterator(source_iterator, _step_function, initial_state=_random.getstate())
+
+
+def BlockwiseShuffleIterator(source_iterator: CheckpointableIterator, block_size: int, seed: int = 0):
+    """
+    Shuffles a sequence of items by grouping consecutive items in blocks of fixed size, shuffling
+    each block, and yielding the shuffled items of all blocks as a flat sequence.
+
+    E.g. [1, 2, 3, 4, 5, 6, 7, 8] with block_size = 3 may yield [3, 1, 2, 4, 6, 5, 8, 7].
+
+    Args:
+        source_iterator: checkpointable iterator or restartable iterable over input items to shuffle
+        block_size: size of the buffer in number of items used for shuffling
+        seed: random seed used for shuffling (or None)
+    """
+    # This is implemented as a pipeline:
+    #  - group N consecutive items together
+    #  - shuffle them
+    #  - flatten the result
+    blocks = FixedBatchIterator(source_iterator, batch_size=block_size)
+    def shuffle_block_fn(random: Random, block: List):
+        random.shuffle(block)
+        return block
+    shuffled_blocks = SamplingRandomMapIterator(blocks, transform=shuffle_block_fn, seed=seed)
+    samples = SelectManyIterator(shuffled_blocks, collection_selector=lambda shuffled_block: iter(shuffled_block))
+    return samples
+
+
+class PrefetchIterator(CheckpointableIterator):
+    """
+    An iterator that prefetches data into a buffer on a separate thread to smooth out IO latency.
+
+    Args:
+        source_iterator: checkpointable iterator to recur over
+        buffer_size: size of the queue between the threads
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, buffer_size: int=1000):
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type:CheckpointableIterator
+        self._buffer_size = buffer_size          # type: int
+        self._queue = None                       # type: Optional[ClosableQueue]
+        self._thread = None                      # type: Optional[Thread]
+        self.setstate(None)
+        
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_state,
+                'item_offset' : self._item_offset  }
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        if self._thread is not None:  # if there is a prefetching thread running, close the queue and wait for the thread to terminate
+            assert self._queue is not None
+            self._queue.close()
+            self._thread.join()
+        
+        self._source_state = checkpoint['source_state'] if checkpoint is not None else None
+        self._item_offset  = checkpoint['item_offset' ] if checkpoint is not None else 0
+
+        self._source_iterator.setstate(self._source_state)
+
+        self._queue = ClosableQueue(maxsize=self._buffer_size)  # clear queue
+        # make thread daemonic so it is killed when the main program terminates
+        self._thread = Thread(target=self._prefetch_thread_fn, args=(self._source_iterator, self._item_offset, self._buffer_size, self._queue), daemon=True)
+        self._thread.start()
+
+    @staticmethod
+    def _prefetch_thread_fn(source, item_offset, buffer_size, queue):  # behavior of the prefetching thread, only call from that thread!
+        _advance_iterator(source, item_offset)  # skip to checkpoint
+
+        while True:
+            try:
+                item = next(source)
+            except StopIteration:
+                queue.close()
+                return
+            
+            if item_offset == buffer_size - 1:  # send a new source state at the END of each window of length _buffer_size
+                source_state = source.getstate()  # this is the state for retrieving the NEXT element, i.e. the first element of the next buffer
+                item_offset = 0
+            else:
+                source_state = None
+                item_offset += 1
+            msg = (item, source_state)
+
+            try:
+                queue.put(msg)
+            except ClosedException:
+                return
+
+    def __next__(self):
+        try:
+            msg = self._queue.get()
+        except ClosedException:
+            raise StopIteration
+
+        item, prefetch_source_state = msg
+        if prefetch_source_state is not None:
+            assert self._item_offset == self._buffer_size - 1  # we expect a new source state at the END of each window of length _buffer_size
+            self._source_state = prefetch_source_state
+            self._item_offset = 0
+        else:
+            self._item_offset = self._item_offset + 1
+            assert self._item_offset < self._buffer_size
+        return item  # for debugging, it's useful to return msg instead of item
+
+    def __del__(self):  # note: this is often not called. If you really need it, gc.collect() will do the trick.
+        if self._thread is not None:
+            assert self._queue is not None
+            self._queue.close()
+            try:
+                self._thread.join()
+            except:
+                pass
+
+class BucketedReadaheadBatchIterator(CheckpointableIterator):
+    """
+    Iterates over items from a checkpointable iterator and groups items of similar length into batches.
+
+    The algorithm reads ahead a certain number of lines (e.g. 10 million), sorts them by
+    length, and then groups them into batches from start to end. The sort is stable, such
+    that prior randomization is not undone (except for the length grouping). The batch size
+    is dynamic, and determined by a user-provided callback.
+
+    This is based on Marian NMT's BatchGenerator.
+    """
+
+    def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key: Callable[[Any], Any], batch_size: Union[int,Callable[[Any], int]], shuffle: bool=True, seed: Optional[int]=None):
+        """
+        Args:
+            source_iterator: The data set that is read from. Typically this is an infinite source.
+            read_ahead: Number of items to fetch ahead for grouping purposes.
+            key: User-provided callback to define how data is sorted for purpose of batching.
+    
        batch_size: Batch size in number of items. Either an integer or a callback to determine batch size for a given first batch item.
+            shuffle: Pass False to not randomize the batches. (default: True)
+            seed: Random seed for batch shuffling.
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        # keep arguments
+        self._key = key                # type: Callable[[Any], Any]
+        self._batch_size = batch_size  # type: Union[int,Callable[[Any], int]]
+        self._read_ahead = read_ahead  # type: int
+        # initialize state
+        self._random = None
+        if shuffle:
+            self._random = Random()                    # type: Random
+            if seed is not None:
+                self._random.seed(seed)
+        self._source_iterator = iter(source_iterator)  # type: CheckpointableIterator
+        self.setstate(None)
+
+    def getstate(self):
+        return {'source_state': self._source_state,
+                'random_state': self._random_state,
+                'num_served':   self._num_batches_yielded}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state        = checkpoint['source_state'] if checkpoint else None  # type: Dict  -- state of input before reading the current set of batches
+        self._random_state        = checkpoint['random_state'] if checkpoint else None  # type: Any   -- state of random generator at _source_state
+        self._num_batches_yielded = checkpoint['num_served']   if checkpoint else 0     # type: int   -- number of batches served from the current set of batches
+        # checkpointing: restore to start of current set of batches
+        self._source_iterator.setstate(self._source_state)
+        if self._random_state:
+            self._random.setstate(self._random_state)
+        self._source_exhausted = False  # type: bool  -- set to True once we hit StopIteration on source
+        def _generate():
+            skip_to_checkpoint = self._num_batches_yielded
+            source_exhausted = False
+            while not source_exhausted:
+                # prefetch the readahead buffer
+                self._source_state = self._source_iterator.getstate()
+                self._random_state = self._random.getstate() if self._random else None
+                items = list(islice(self._source_iterator, self._read_ahead))
+                source_exhausted = (len(items) < self._read_ahead)
+                # create batches
+                batches = self._create_batches(items)
+                # shuffle the batches
+                if self._random:
+                    self._random.shuffle(batches)
+                # on first loop iteration, restore iterator inside batches from checkpoint
+                batches = iter(batches)
+                self._num_batches_yielded = _advance_iterator(batches, skip_to_checkpoint)
+                skip_to_checkpoint = 0
+                # main loop over batches in current read-ahead section
+                for batch in batches:
+                    self._num_batches_yielded += 1
+                    yield batch
+        self._iterator = _generate()  # type: Iterator  -- iterator into current set of batches
+
+    def _create_batches(self, items: List[Any]) -> List[List[Any]]:  # helper to form batches from a list of items
+            # sort by length, longest first
+            items.sort(key=self._key, reverse=True)  # note: sort() is stable, so we won't undo any randomization besides the bucketing
+            # group into batches
+            cur_batch = None
+            batches = []
+            for item in items:
+                if not cur_batch:
+                    batch_size = self._batch_size if isinstance(self._batch_size, int) else \
+                                 self._batch_size(item)
+                    cur_batch = []
+                cur_batch.append(item)
+                if len(cur_batch) >= batch_size:  # this batch is full
+                    batches.append(cur_batch)
+                    cur_batch = None
+            if cur_batch:
+                batches.append(cur_batch)
+            return batches
+
+    def __next__(self):
+        return next(self._iterator)
+
+
+
+
+
+
+
+

Functions

+
+
+def create_source_iterator(source_items: List, train: bool = True, seed: Union[int, NoneType] = None, shuffle: bool = True, num_instances: int = 1, instance_rank: int = 0) +
+
+
+
def create_source_iterator(source_items: List, train: bool=True, seed: Optional[int]=None, shuffle: bool=True, num_instances: int=1, instance_rank: int=0):
+    if not train and shuffle:
+        raise ValueError('shuffling is not supported when train=False')
+    if train:
+        return InfinitePermutationSourceIterator(source_items, seed=seed, shuffle=shuffle, num_instances=num_instances, instance_rank=instance_rank)
+    else:
+        return ChunkedSourceIterator(source_items, num_instances=num_instances, instance_rank=instance_rank)
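As an illustrative sketch (toy corpus and variable names are hypothetical), the same call site can produce an infinite shuffled, sharded stream for training and a single-pass per-instance stream for evaluation:

from infinibatch.iterators import create_source_iterator

corpus = [f"line {i}" for i in range(10)]               # toy in-memory corpus
train_source = create_source_iterator(corpus, train=True, seed=1, num_instances=2, instance_rank=0)
eval_source  = create_source_iterator(corpus, train=False, shuffle=False, num_instances=2, instance_rank=0)
print(list(eval_source))        # one pass over this instance's chunk of the corpus
print(next(train_source))       # infinite shuffled stream; consume with next()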
+
+
+
+def ChunkedSourceIterator(source_items: List, num_instances: int = 1, instance_rank: int = 0) +
+
+

Cuts source list into chunks, one per instance, and serves out items in chunk corresponding to instance_rank

+

This is a source iterator: +It is meant to be used at the beginning of a data loading pipeline. +As such, it takes a list as its source and not a CheckpointableIterator.

+

Args

+
+
source_items
+
input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it!
+
num_instances
+
number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+
instance_rank
+
rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+
+
def ChunkedSourceIterator(source_items: List, num_instances: int=1, instance_rank: int=0):
+    """
+    Cuts source list into chunks, one per instance, and serves out items in chunk corresponding to instance_rank
+
+    This is a source iterator:
+    It is meant to be used at the beginning of a data loading pipeline.
+    As such, it takes a list as its source and not a CheckpointableIterator.
+
+    Args:
+        source_items: input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it!
+        num_instances: number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+        instance_rank: rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+    """
+    # heuristic: assuming blocks are all of the same size, math.ceil should give us the shortest makespan
+    chunk_size = math.ceil(len(source_items) / num_instances)
+    # this does not cause any out-of-bounds issues:
+    # a slice with a start-index beyond the end of the list is empty,
+    # and an end-index of a slice is capped at the end of the list
+    chunk = source_items[instance_rank * chunk_size : (instance_rank + 1) * chunk_size]
+    return NativeCheckpointableIterator(chunk)
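A small sketch of the chunking rule (toy list): with 10 items and 3 instances the chunk size is ceil(10/3) = 4, so the last instance receives a shorter chunk.

from infinibatch.iterators import ChunkedSourceIterator

data = list(range(10))
print(list(ChunkedSourceIterator(data, num_instances=3, instance_rank=0)))  # [0, 1, 2, 3]
print(list(ChunkedSourceIterator(data, num_instances=3, instance_rank=2)))  # [8, 9]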
+
+
+
+def ParallelMapIterator(source_iterator: CheckpointableIterator, transform: Callable[[str], Any], num_processes: int, num_items_per_process: int) +
+
+

Applies given transform to each data item

+

Behaves the same as MapIterator, but applies transform in parallel using multiple processes in a parallel map operation.

+

Warning: +The transform function has to be pickleable because it is sent across process boundaries. +To achieve this, transform should be a top-level function.

+

Args

+
+
source_iterator
+
checkpointable iterator
+
transform
+
function to be applied to each data item, has to be pickleable, see above
+
num_processes
+
number of processes to use for parallel map
+
num_items_per_process
+
number of data items each process operates on
+
+
def ParallelMapIterator(source_iterator: CheckpointableIterator, transform: Callable[[str],Any], num_processes: int, num_items_per_process: int):
+    """
+    Applies given transform to each data item
+
+    Behaves the same as MapIterator, but applies transform in parallel using multiple processes in a parallel map operation.
+
+    Warning:
+    The transform function has to be pickleable because it is sent across process boundaries.
+    To achieve this, transform should be a top-level function.
+
+    Args:
+        source_iterator: checkpointable iterator
+        transform: function to be applied to each data item, has to be pickleable, see above
+        num_processes: number of processes to use for parallel map
+        num_items_per_process: number of data items each process operates on
+    """
+    # divide stream of data items into batches
+    batched_samples = FixedBatchIterator(source_iterator, num_processes * num_items_per_process)
+    # create process pool and capture it in closure that performs parallel map
+    p = Pool(num_processes)
+    def parallel_map_transform(buffer):
+        return p.map(transform, buffer)
+    # apply transform in parallel to data items in a batch
+    batched_transformed_samples = MapIterator(batched_samples, parallel_map_transform)
+    # unpack batches to go back to stream of (now transformed) data items
+    transformed_samples = SelectManyIterator(batched_transformed_samples)
+    return transformed_samples
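A usage sketch under the pickling constraint described above; the tokenizer and toy data are hypothetical, and the __main__ guard is needed on platforms that spawn worker processes:

from infinibatch.iterators import NativeCheckpointableIterator, ParallelMapIterator

def tokenize(line: str):                      # top-level function, hence pickleable
    return line.split()

if __name__ == '__main__':                    # required where multiprocessing uses spawn
    source = NativeCheckpointableIterator(["a b c", "d e", "f g h i"])
    tokenized = ParallelMapIterator(source, transform=tokenize, num_processes=2, num_items_per_process=2)
    print(list(tokenized))                    # [['a', 'b', 'c'], ['d', 'e'], ['f', 'g', 'h', 'i']]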
+
+
+
+def SamplingRandomMapIterator(source_iterator: CheckpointableIterator, transform: Callable[[random.Random, Any], Any], seed: Union[int, NoneType] = None) +
+
+

An iterator that calls a transform function on each item, while also passing a checkpointed +random generator.

+

Args

+
+
source_iterator
+
checkpointable iterator to recur over
+
transform
+
user-supplied function with signature transform(random, item) -> result_item
+
seed
+
random seed
+
+
def SamplingRandomMapIterator(source_iterator: CheckpointableIterator, transform: Callable[[Random,Any],Any], seed: Optional[int]=None):
+    """
+    An iterator that calls a transform function on each item, while also passing a checkpointed
+    random generator.
+
+    Args:
+        source_iterator: checkpointable iterator to recur over
+        transform: user-supplied function with signature transform(random, item) -> result_item
+        seed: random seed
+    """
+    _random = Random()
+    if seed is not None:
+        _random.seed(seed)
+    def _step_function(state, item):
+        _random.setstate(state)
+        output = transform(_random, item)
+        return _random.getstate(), output
+    return RecurrentIterator(source_iterator, _step_function, initial_state=_random.getstate())
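A sketch of a transform that draws from the checkpointed random generator (the jitter function and toy data are hypothetical):

from random import Random
from infinibatch.iterators import NativeCheckpointableIterator, SamplingRandomMapIterator

def add_noise(rng: Random, x: float) -> float:
    return x + rng.uniform(-0.1, 0.1)         # rng state travels with the checkpoint

source = NativeCheckpointableIterator([1.0, 2.0, 3.0])
noisy = SamplingRandomMapIterator(source, transform=add_noise, seed=7)
print(list(noisy))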
+
+
+
+def BlockwiseShuffleIterator(source_iterator: CheckpointableIterator, block_size: int, seed: int = 0) +
+
+

Shuffles a sequence of items by grouping consecutive items in blocks of fixed size, shuffling +each block, and yielding the shuffled items of all blocks as a flat sequence.

+

E.g. [1, 2, 3, 4, 5, 6, 7, 8] with block_size = 3 may yield [3, 1, 2, 4, 6, 5, 8, 7].

+

Args

+
+
source_iterator
+
checkpointable iterator or restartable iterable over input items to shuffle
+
block_size
+
size of the buffer in number of items used for shuffling
+
seed
+
random seed used for shuffling (or None)
+
+
def BlockwiseShuffleIterator(source_iterator: CheckpointableIterator, block_size: int, seed: int = 0):
+    """
+    Shuffles a sequence of items by grouping consecutive items in blocks of fixed size, shuffling
+    each block, and yielding the shuffled items of all blocks as a flat sequence.
+
+    E.g. [1, 2, 3, 4, 5, 6, 7, 8] with block_size = 3 may yield [3, 1, 2, 4, 6, 5, 8, 7].
+
+    Args:
+        source_iterator: checkpointable iterator or restartable iterable over input items to shuffle
+        block_size: size of the buffer in number of items used for shuffling
+        seed: random seed used for shuffling (or None)
+    """
+    # This is implemented as a pipeline:
+    #  - group N consecutive items together
+    #  - shuffle them
+    #  - flatten the result
+    blocks = FixedBatchIterator(source_iterator, batch_size=block_size)
+    def shuffle_block_fn(random: Random, block: List):
+        random.shuffle(block)
+        return block
+    shuffled_blocks = SamplingRandomMapIterator(blocks, transform=shuffle_block_fn, seed=seed)
+    samples = SelectManyIterator(shuffled_blocks, collection_selector=lambda shuffled_block: iter(shuffled_block))
+    return samples
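A sketch on toy data; memory stays bounded by the block size and items are only rearranged within their own block:

from infinibatch.iterators import NativeCheckpointableIterator, BlockwiseShuffleIterator

source = NativeCheckpointableIterator(list(range(8)))
shuffled = BlockwiseShuffleIterator(source, block_size=3, seed=1)
print(list(shuffled))   # a permutation of 0..7 in which each item stays inside its block of 3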
+
+
+
+
+
+

Classes

+
+
+class CheckpointableIterator +
+
+

Abstract base class that defines the interface for checkpointing.

+

The interface (getstate, setstate) is inspired by Python's random package.

+
class CheckpointableIterator(collections.abc.Iterator):
+    """
+    Abstract base class that defines the interface for checkpointing.
+    
+    The interface (getstate, setstate) is inspired by Python's random package.
+    """
+    def __iter__(self):
+        return self
+
+    @abstractmethod
+    def getstate(self) -> Dict:
+        """
+        Get checkpoint of current state of iterator
+        
+        In a pipeline of iterators, this function __recursively__ calls itself on the preceding iterator
+        and includes the gathered information in the returned checkpoint.
+        Thereby, to obtain a checkpoint of the state of an entire pipeline of iterators
+        you only have to call this function on the __last__ iterator in the pipeline.
+        A checkpoint is represented as a `dict`,
+        but the caller should treat a checkpoint as an opaque object
+        and not make any assumptions about the existence or meaning of the `dict` entries.
+        """
+        pass
+
+    @abstractmethod
+    def setstate(self, checkpoint: Optional[Dict]):
+        """
+        Set state of iterator to given checkpoint
+
+        In a pipeline of iterators, this function __recursively__ calls itself on the preceding iterator.
+        Thereby, to set the state of an entire pipeline of iterators to a given checkpoint
+        you only have to call this function on the __last__ iterator in the pipeline.
+
+        Args:
+            checkpoint: Checkpoint that should be used to reset the state of the iterator (or pipeline).
+                        If this is __None__, the state of the iterator (or pipeline) is reset to the initial
+                        state immediately after construction.
+        """
+        pass
+
+    def __getstate__(self) -> Dict:  # implementation of pickle Protocol
+        return self.getstate()
+
+    def __setstate__(self, checkpoint: Optional[Dict]):
+        self.setstate(checkpoint)
+
+    @abstractmethod
+    def __next__(self):
+        pass
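To make the getstate/setstate contract concrete, a minimal sketch with a toy two-stage pipeline; calling the methods on the last iterator checkpoints and restores the whole chain:

from infinibatch.iterators import NativeCheckpointableIterator, MapIterator

pipeline = MapIterator(NativeCheckpointableIterator(list(range(10))), transform=lambda x: x * x)
print([next(pipeline) for _ in range(3)])   # [0, 1, 4]
checkpoint = pipeline.getstate()            # opaque dict covering the whole pipeline
rest_a = list(pipeline)                     # finish the pass
pipeline.setstate(checkpoint)               # rewind the entire pipeline to the checkpoint
rest_b = list(pipeline)
assert rest_a == rest_b                     # iteration resumes identically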
+
+

Ancestors

+
    +
  • collections.abc.Iterator
  • +
  • collections.abc.Iterable
  • +
+

Subclasses

+ +

Methods

+
+
+def getstate(self) -> Dict +
+
+

Get checkpoint of current state of iterator

+

In a pipeline of iterators, this function recursively calls itself on the preceding iterator +and includes the gathered information in the returned checkpoint. +Thereby, to obtain a checkpoint of the state of an entire pipeline of iterators +you only have to call this function on the last iterator in the pipeline. +A checkpoint is represented as a dict, +but the caller should treat a checkpoint as an opaque object +and not make any assumptions about the existence or meaning of the dict entries.

+
@abstractmethod
+def getstate(self) -> Dict:
+    """
+    Get checkpoint of current state of iterator
+    
+    In a pipeline of iterators, this function __recursively__ calls itself on the preceding iterator
+    and includes the gathered information in the returned checkpoint.
+    Thereby, to obtain a checkpoint of the state of an entire pipeline of iterators
+    you only have to call this function on the __last__ iterator in the pipeline.
+    A checkpoint is represented as a `dict`,
+    but the caller should treat a checkpoint as an opaque object
+    and not make any assumptions about the existence or meaning of the `dict` entries.
+    """
+    pass
+
+
+
+def setstate(self, checkpoint: Union[Dict, NoneType]) +
+
+

Set state of iterator to given checkpoint

+

In a pipeline of iterators, this function recursively calls itself on the preceding iterator. +Thereby, to set the state of an entire pipeline of iterators to a given checkpoint +you only have to call this function on the last iterator in the pipeline.

+

Args

+
+
checkpoint
+
Checkpoint that should be used to reset the state of the iterator (or pipeline). +If this is None, the state of the iterator (or pipeline) is reset to the initial +state immediately after construction.
+
+
@abstractmethod
+def setstate(self, checkpoint: Optional[Dict]):
+    """
+    Set state of iterator to given checkpoint
+
+    In a pipeline of iterators, this function __recursively__ calls itself on the preceding iterator.
+    Thereby, to set the state of an entire pipeline of iterators to a given checkpoint
+    you only have to call this function on the __last__ iterator in the pipeline.
+
+    Args:
+        checkpoint: Checkpoint that should be used to reset the state of the iterator (or pipeline).
+                    If this is __None__, the state of the iterator (or pipeline) is reset to the initial
+                    state immediately after construction.
+    """
+    pass
+
+
+
+
+
+class NativeCheckpointableIterator +(iterable: Iterable) +
+
+

Simple wrapper class that turns a Python Iterable into a CheckpointableIterator

+

When calling setstate on this class, it simply replays the iterator all the way to the checkpoint one element at a time, +which makes it generally inefficient.

+

Warning: This class cannot be used with Iterators (as opposed to Iterables), which have an __iter__ function that simply returns self, but does not reset.

+
class NativeCheckpointableIterator(CheckpointableIterator):
+    """
+    Simple wrapper class that turns a Python Iterable into a CheckpointableIterator
+    
+    When calling setstate on this class, it simply replays the iterator all the way to the checkpoint one element at a time,
+    which makes it generally inefficient.
+
+    Warning: This class cannot be used with Iterators (as opposed to Iterables), which have an `__iter__` function that simply returns self, but does not reset.
+    """
+    def __init__(self, iterable: Iterable):
+        # check whether iterable is iterable or iterator:
+        # if the variable iterable contains an iterator, the function __iter__ returns self
+        # if the variable iterable is an actual iterator, it should not return self
+        if iter(iterable) is iterable:  
+            raise ValueError('It looks like you are passing an iterator instead of an iterable. This is not supported and can cause undefined behavior when used with checkpointing.')
+        self._input_iterable = iterable
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'num_items_yielded': self._num_items_yielded}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._iterator = iter(self._input_iterable)
+        self._num_items_yielded = _advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
+
+    def __next__(self):
+        item = next(self._iterator)  # call this before increasing _num_items_yielded to correctly handle the case when a StopIteration exception is thrown
+        self._num_items_yielded += 1
+        return item
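A sketch on a toy list showing checkpointing by replay:

from infinibatch.iterators import NativeCheckpointableIterator

it = NativeCheckpointableIterator(['a', 'b', 'c', 'd'])   # a list is an Iterable, not an Iterator
next(it); next(it)                                        # consume 'a' and 'b'
state = it.getstate()                                     # e.g. {'num_items_yielded': 2}
print(list(it))                                           # ['c', 'd']
it.setstate(state)                                        # replays the underlying iterable up to the checkpoint
print(list(it))                                           # ['c', 'd'] again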
+
+

Ancestors

+ +

Inherited members

+ +
+
+class InfinitePermutationSourceIterator +(source_items: List, seed: Union[int, NoneType] = None, shuffle: bool = True, num_instances: int = 1, instance_rank: int = 0) +
+
+

Infinitely generates permutations of the items in the given list.

+

This is a source iterator: +It is meant to be used at the beginning of a data loading pipeline. +As such, it takes a list as its source and not a CheckpointableIterator. +The given list is loaded completely into RAM.

+

For example, this is used for randomizing the pathnames of data blocks read by ChunkedReadlinesIterator.

+

Args

+
+
source_items
+
input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it!
+
seed
+
random seed used for shuffling (or None)
+
shuffle
+
set False to bypass the shuffling. Then this is just a checkpointed version of itertools.cycle(). (Default: True)
+
num_instances
+
number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+
instance_rank
+
rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+
+
class InfinitePermutationSourceIterator(CheckpointableIterator):
+    """
+    Infinitely generates permutations of the items in the given list.
+
+    This is a source iterator:
+    It is meant to be used at the beginning of a data loading pipeline.
+    As such, it takes a list as its source and not a CheckpointableIterator.
+    The given list is loaded completely into RAM.
+
+    For example, this is used for randomizing the pathnames of data blocks read by ChunkedReadlinesIterator.
+    """
+    def __init__(self, source_items: List, seed: Optional[int]=None, shuffle: bool=True, num_instances: int=1, instance_rank: int=0):
+        """
+        Args:
+            source_items: input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it!
+            seed: random seed used for shuffling (or None)
+            shuffle: set False to bypass the shuffling. Then this is just a checkpointed version of itertools.cycle(). (Default: True)
+            num_instances: number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+            instance_rank: rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training.
+        """
+        self._source_items = source_items
+        if not self._source_items:
+            raise ValueError("InfinitePermutationIterator: source must not be empty")
+        self._shuffle = shuffle
+        self._seed = seed
+        self._num_instances = num_instances
+        self._instance_rank = instance_rank
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'random_state':      self._random_state,  # state of random generator before generating the current shuffling of the sequence
+                'num_items_yielded': self._num_items_yielded}    # how many items have already been iterated over in the current shuffling
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        # set iteration state. Do this outside the generator below in case getstate() is called before ever iterating
+        self._random_state      = checkpoint['random_state']      if checkpoint else None
+        self._num_items_yielded = checkpoint['num_items_yielded'] if checkpoint else 0
+        # We define the iteration itself as a generator for ease of implementation.
+        # We could as well just have used an explicit state machine represented by class members.
+        def _generate() -> Iterator:
+            # create and reset random generator
+            random = Random(self._seed)
+            if self._random_state is not None:  # restore the random generator's state
+                random.setstate(self._random_state)
+            skip_to_checkpoint = self._num_items_yielded  # items to skip in order to advance to checkpoint
+            # main outer loop for infinite passes over items (reshuffle before each pass)
+            while True:
+                # (re-)shuffle all items
+                self._random_state = random.getstate()  # remember random state before shuffling
+                self._num_items_yielded   = 0
+                shuffled_items = self._source_items[:]  # note: if underlying iterator is checkpointable, use setstate(checkpoint['nested_state']) on it
+                if self._shuffle:
+                    random.shuffle(shuffled_items)
+                shuffled_iterator = iter(shuffled_items)
+                # skip initial items when restarting from checkpoint
+                if skip_to_checkpoint:  # @TODO: find a way to abstract this more, so that we can plug it into the 'for' statement directly
+                    self._num_items_yielded += _advance_iterator(shuffled_iterator, skip_to_checkpoint)
+                    skip_to_checkpoint = 0  # done skipping
+                # main inner loop over items
+                for item in shuffled_iterator:
+                    self._num_items_yielded += 1  # record how many items we have iterated over in this pass over the items
+                    if (self._num_items_yielded-1) % self._num_instances == self._instance_rank:  # built-in islice facility
+                        yield item
+        self._iterator = _generate()
+
+    def __next__(self):
+        return next(self._iterator)
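A sketch on a toy list: the stream never ends, reshuffles on every pass, and with num_instances=2 this rank sees every other item of each pass:

from infinibatch.iterators import InfinitePermutationSourceIterator

it = InfinitePermutationSourceIterator(['a', 'b', 'c'], seed=0, num_instances=2, instance_rank=0)
print([next(it) for _ in range(6)])   # never raises StopIteration; reshuffled on every pass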
+
+

Ancestors

+ +

Inherited members

+ +
+
+class SelectManyIterator +(source_iterator: CheckpointableIterator, collection_selector: Union[Callable[[Any], Iterator], NoneType] = None) +
+
+

Projects each element of a source sequence to a sequence and flattens the resulting sequences into one sequence.

+

Args

+
+
source_iterator
+
iterator over the items to pass to collection_selector()
+
collection_selector
+
user callback that maps an item into an Iterable, whose items will be yielded. +The returned Iterator is used only once. Hence, it is also allowed to +return self-iterables, such as iterators and generator expressions. +If None is given, no callback is applied.
+
+
class SelectManyIterator(CheckpointableIterator):
+    """
+    Projects each element of a source sequence to a sequence and flattens the resulting sequences into one sequence.
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, collection_selector: Optional[Callable[[Any], Iterator]]=None):
+        """
+        Args:
+            source_iterator: iterator over the items to pass to collection_selector()
+            collection_selector: user callback that maps an item into an Iterable, whose items will be yielded.
+                                 The returned Iterator is used only once. Hence, it is also allowed to
+                                 return self-iterables, such as iterators and generator expressions.
+                                 If None is given, no callback is applied.
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator          # type: CheckpointableIterator
+        self._collection_selector = collection_selector  # type: Callable[[Any], Iterator]
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state':            self._source_state,
+                'flattened_items_yielded': self._flattened_items_yielded}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state            = checkpoint['source_state']            if checkpoint else None
+        self._flattened_items_yielded = checkpoint['flattened_items_yielded'] if checkpoint else 0
+        self._source_iterator.setstate(self._source_state)
+        def _generate():
+            skip_to_checkpoint = self._flattened_items_yielded
+            # main loop over source source_items
+            for source_item in self._source_iterator:
+                if self._collection_selector is not None:
+                    data = iter(self._collection_selector(source_item))
+                else:
+                    data = iter(source_item)
+                self._flattened_items_yielded = 0
+                if skip_to_checkpoint:
+                    #print("Skipping to index", skip_to_checkpoint, file=sys.stderr)
+                    self._flattened_items_yielded += _advance_iterator(data, skip_to_checkpoint)
+                    skip_to_checkpoint = 0
+                # main loop over lines
+                for item in data:
+                    self._flattened_items_yielded += 1
+                    yield item
+                self._source_state = self._source_iterator.getstate()
+        self._iterator = _generate()
+
+    def __next__(self):
+        return next(self._iterator)
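A sketch that flattens lines of text into a stream of tokens (toy data):

from infinibatch.iterators import NativeCheckpointableIterator, SelectManyIterator

lines = NativeCheckpointableIterator(["a b", "c d e"])
tokens = SelectManyIterator(lines, collection_selector=lambda line: line.split())
print(list(tokens))   # ['a', 'b', 'c', 'd', 'e']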
+
+

Ancestors

+ +

Inherited members

+ +
+
+class BufferedShuffleIterator +(source_iterator: CheckpointableIterator, buffer_size: int, seed: int = 0) +
+
+

Shuffles given iterable using a limited buffer.

+

Args

+
+
source_iterator
+
checkpointable iterator or restartable iterable over input items to shuffle
+
buffer_size
+
size of the buffer in number of items used for shuffling
+
seed
+
random seed used for shuffling (or None)
+
+
class BufferedShuffleIterator(CheckpointableIterator):
+    """
+    Shuffles given iterable using a limited buffer.
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, buffer_size: int, seed: int = 0):
+        """
+        Args:
+            source_iterator: checkpointable iterator or restartable iterable over input items to shuffle
+            buffer_size: size of the buffer in number of items used for shuffling
+            seed: random seed used for shuffling (or None)
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator
+        self._buffer = [None for _ in range(buffer_size)]  # maybe do this lazily?   --Yes, since user may set state immediately, then this is not needed here
+        self._random = Random(seed)
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_iterator.getstate(),
+                'buffer':       copy.deepcopy(self._buffer),
+                'random_state': self._random.getstate()}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        if checkpoint:
+            self._source_iterator.setstate(checkpoint['source_state'])
+            self._buffer = checkpoint['buffer']
+            self._random.setstate(checkpoint['random_state'])
+            # @TODO: Can we add a comment how the flush part is handled?
+        else:
+            self._source_iterator.setstate(None)
+        self._iterator = self._generate()
+
+    def _generate(self) -> Iterator:
+        # shuffle data with a buffer:
+        # this is similar to what the Fisher-Yates shuffle does,
+        # but modified to run with a constant-size buffer
+        # see https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
+        # this was inspired by an algorithm implemented in Kaldi
+        # see https://kaldi-asr.org/doc/nnet-shuffle-egs_8cc.html
+        for item in self._source_iterator:
+            index = self._random.randrange(0, len(self._buffer))
+            result = None
+            if self._buffer[index] is not None:
+                result = self._buffer[index]
+            self._buffer[index] = item
+            # only yield value once buffer is updated to allow for correct checkpointing!
+            if result is not None:
+                yield result
+
+        # flush buffer
+        while self._buffer:
+            item = self._buffer.pop()
+            if item is not None:
+                yield item
+
+    def __next__(self):
+        return next(self._iterator)
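A sketch on toy data; the result is a permutation whose randomness is limited by the buffer size:

from infinibatch.iterators import NativeCheckpointableIterator, BufferedShuffleIterator

source = NativeCheckpointableIterator(list(range(10)))
shuffled = BufferedShuffleIterator(source, buffer_size=4, seed=3)
print(list(shuffled))   # a permutation of 0..9, shuffled only as far as the buffer allows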
+
+

Ancestors

+ +

Inherited members

+ +
+
+class MapIterator +(source_iterator: CheckpointableIterator, transform: Callable[[str], Any]) +
+
+

Applies given transform to each data item

+

Args

+
+
source_iterator
+
checkpointable iterator
+
transform
+
function to be applied to each data item
+
+
class MapIterator(CheckpointableIterator):
+    """
+    Applies given transform to each data item
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, transform: Callable[[str],Any]):
+        """
+        Args:
+            source_iterator: checkpointable iterator
+            transform: function to be applied to each data item
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator
+        self._transform = transform
+
+    def getstate(self) -> Dict:
+        return self._source_iterator.getstate()
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_iterator.setstate(checkpoint)
+
+    def __next__(self):
+        return self._transform(next(self._source_iterator))
+
+

Ancestors

+ +

Inherited members

+ +
+
+class ZipIterator +(*source_iterators: CheckpointableIterator) +
+
+

Zips items from all given iterators, like the Python standard function zip().

+

Like Python's built-in zip(), the iteration stops when the shortest input iterable is exhausted.

+

Args

+
+
source_iterators
+
list of iterators to zip, item by item
+
+
class ZipIterator(CheckpointableIterator):
+    """
+    Zips items from all given iterators, like the Python standard function zip().
+
+    Like Python's built-in zip(), the iteration stops when the shortest input iterable is exhausted.
+    """
+    def __init__(self, *source_iterators: CheckpointableIterator):
+        """
+        Args:
+            source_iterators: list of iterators to zip, item by item
+        """
+        for source_iterator in source_iterators:
+            if not isinstance(source_iterator, CheckpointableIterator):
+                raise ValueError('all iterators in source_iterators have to be CheckpointableIterator')
+        self._source_iterators = source_iterators    # type: List[CheckpointableIterator]
+
+    def getstate(self) -> Dict:
+        return {'input_states': tuple(iterator.getstate() for iterator in self._source_iterators)}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        if checkpoint is None:
+            for iterator in self._source_iterators:
+                iterator.setstate(None)
+        else:
+            for iterator, state in zip(self._source_iterators, checkpoint['input_states']):
+                iterator.setstate(state)
+
+    def __next__(self):
+        res = []  # (note: can't use a generator expression, as it gets confused when a next() call raises StopIteration)
+        for iterator in self._source_iterators:
+            res.append(next(iterator))
+        return tuple(res)
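A sketch on toy streams; as with zip(), iteration stops with the shorter input:

from infinibatch.iterators import NativeCheckpointableIterator, ZipIterator

texts = NativeCheckpointableIterator(["a b", "c", "d e f"])
ids   = NativeCheckpointableIterator([0, 1])
print(list(ZipIterator(texts, ids)))   # [('a b', 0), ('c', 1)]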
+
+

Ancestors

+ +

Inherited members

+ +
+
+class WindowedIterator +(source_iterator: CheckpointableIterator, width: int) +
+
+

Yields 'width' consecutive items in a sliding window.

+

E.g. [1, 2, 3, 4, 5, 6] with width = 3 will yield +[[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]

+

Args

+
+
source_iterator
+
checkpointable input iterators
+
+
class WindowedIterator(CheckpointableIterator):
+    """
+    Yields 'width' consecutive items in a sliding window.
+
+    E.g. [1, 2, 3, 4, 5, 6] with width = 3 will yield
+    [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, width: int):
+        """
+        Args:
+            source_iterator: checkpointable input iterators
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type: CheckpointableIterator
+        self._width = width                      # type: int
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_state,  # state for first item in FIFO
+                'item_index':  self._item_index}   # index of next item to serve
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state = checkpoint['source_state'] if checkpoint else None
+        self._item_index   = checkpoint['item_index']   if checkpoint else 0
+        self._source_iterator.setstate(self._source_state)
+        self._iterator = self._generate()
+
+    def _fifo_slice(self, i):  # returns a window into the FIFO beginning at i
+        # @TODO: for efficiency, make this a slice view
+        return tuple(self._fifo[i:i + self._width])
+
+    def _generate(self) -> Iterator:
+        self._source_state = self._source_iterator.getstate()
+        self._fifo = list(islice(self._source_iterator, self._width))
+        # we do this in overlapping blocks of length 2*width, for easier checkpointing and potential efficiency
+        while len(self._fifo) == self._width:
+            # we got 'width' items; append another 'width' (or less if at end)
+            next_input_state = self._source_iterator.getstate()
+            self._fifo.extend(islice(self._source_iterator, self._width))
+            # now serve all positions in first half (last = width - 1). If at end, then limit accordingly.
+            last = min(self._width - 1, len(self._fifo) - self._width)
+            while self._item_index <= last:
+                window = self._fifo_slice(self._item_index)
+                self._item_index += 1
+                yield window
+            # drop all we just served; if < width left, we have hit the end
+            self._fifo = self._fifo[last + 1:]    # Note: This must be a new list, since the old might still be in a slice view.
+            self._source_state = next_input_state  # this reflects now the first element in the FIFO
+            self._item_index = 0
+
+    def __next__(self):
+        return next(self._iterator)
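A sketch on toy data; note that the windows are yielded as tuples:

from infinibatch.iterators import NativeCheckpointableIterator, WindowedIterator

source = NativeCheckpointableIterator([1, 2, 3, 4, 5, 6])
print(list(WindowedIterator(source, width=3)))   # [(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)]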
+
+

Ancestors

+ +

Inherited members

+ +
+
+class FixedBatchIterator +(source_iterator: CheckpointableIterator, batch_size: int) +
+
+

Batches N consecutive items into a single item that is a list of these items.

+

E.g. [1, 2, 3, 4, 5, 6, 7, 8] with batch_size = 3 will yield +[(1, 2, 3), (4, 5, 6), (7, 8)]

+

Args

+
+
source_iterator
+
checkpointable input iterators
+
batch_size
+
number of items per batch
+
+
class FixedBatchIterator(CheckpointableIterator):
+    """
+    Batches N consecutive items into a single item that is a list of these items.
+
+    E.g. [1, 2, 3, 4, 5, 6, 7, 8] with batch_size = 3 will yield
+    [(1, 2, 3), (4, 5, 6), (7, 8)]
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, batch_size: int):
+        """
+        Args:
+            source_iterator: checkpointable input iterators
+            batch_size: number of items per batch
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type: CheckpointableIterator
+        self._batch_size = batch_size            # type: int
+        self.setstate(None)
+
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_iterator.getstate()}  # state for first item in next batch
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state = checkpoint['source_state'] if checkpoint else None
+        self._source_iterator.setstate(self._source_state)
+        self._iterator = self._generate()
+
+    def _generate(self) -> Iterator:
+        while True:
+            batch = list(islice(self._source_iterator, self._batch_size))
+            if not batch:
+                break
+            yield batch
+
+    def __next__(self):
+        return next(self._iterator)
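A sketch on toy data; each batch is yielded as a list and the final one may be short:

from infinibatch.iterators import NativeCheckpointableIterator, FixedBatchIterator

source = NativeCheckpointableIterator(list(range(8)))
print(list(FixedBatchIterator(source, batch_size=3)))   # [[0, 1, 2], [3, 4, 5], [6, 7]]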
+
+

Ancestors

+ +

Inherited members

+ +
+
+class RandomIterator +(seed: Union[int, NoneType] = None) +
+
+

Iterator to generate uniformly distributed random numbers in the interval [0,1). +Very similar to Random.random(), except that random numbers are +obtained via next().

+

Args

+
+
seed
+
Random seed.
+
+
class RandomIterator(CheckpointableIterator):
+    """
+    Iterator to generate uniformly distributed random numbers in the interval [0,1).
+    Very similar to Random.random(), except that random numbers are
+    obtained via next().
+    """
+    def __init__(self, seed: Optional[int]=None):
+        """
+        Args:
+            seed: Random seed.
+        """
+        self._random = Random()  # type: Random
+        if seed is not None:
+            self._random.seed(seed)
+
+    def getstate(self) -> Dict:
+        return {'random_state': self._random.getstate()}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._random.setstate(checkpoint['random_state'] if checkpoint else None)
+
+    def __next__(self):
+        return self._random.random()
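A sketch showing that restoring a checkpoint reproduces the same random draws:

from infinibatch.iterators import RandomIterator

rng_stream = RandomIterator(seed=42)
state = rng_stream.getstate()
a = [next(rng_stream) for _ in range(3)]
rng_stream.setstate(state)
b = [next(rng_stream) for _ in range(3)]
assert a == b   # identical numbers after restoring the checkpoint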
+
+

Ancestors

+ +

Inherited members

+ +
+
+class RecurrentIterator +(source_iterator: CheckpointableIterator, step_function: Callable[[Any, Any], Tuple[Any, Any]], initial_state: Any = None) +
+
+

Iterates statefully over a step function. The step function accepts a state and a new item, +and returns a new state and an output item, which is yielded.

+

Args

+
+
source_iterator
+
checkpointable iterator to recur over
+
step_function
+
user-supplied function with signature step_function(state, item) -> (new_state, output)
+
initial_state
+
initial state to be passed to the step_function upon first invocation
+
+
class RecurrentIterator(CheckpointableIterator):
+    """
+    Iterates statefully over a step function. The step function accepts a state and a new item,
+    and returns a new state and an output item, which is yielded.
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, step_function: Callable[[Any,Any], Tuple[Any,Any]], initial_state: Any = None):
+        """
+        Args:
+            source_iterator: checkpointable iterator to recur over
+            step_function: user-supplied function with signature step_function(state, item) -> (new_state, output)
+            initial_state: initial state to be passed to the step_function upon first invocation
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type: CheckpointableIterator
+        self._step_function = step_function      # type: Callable[[Any,Any], Tuple[Any,Any]]
+        self._initial_state = initial_state      # type: Any
+        self.setstate(None)
+    
+    def getstate(self):
+        return {'recurrent_state': self._recurrent_state,
+                'source_state':    self._source_iterator.getstate()}
+    
+    def setstate(self, checkpoint):
+        self._recurrent_state = checkpoint['recurrent_state'] if checkpoint else self._initial_state
+        self._source_iterator.setstate(checkpoint['source_state'] if checkpoint else None)
+        def _generate():
+            for item in self._source_iterator:
+                self._recurrent_state, output = self._step_function(self._recurrent_state, item)
+                yield output
+        self._iterator = _generate()
+
+    def __next__(self):
+        return next(self._iterator)
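A sketch of a running sum implemented as a step function (the state is the sum so far, which is also the output):

from infinibatch.iterators import NativeCheckpointableIterator, RecurrentIterator

def running_sum(total, item):
    return total + item, total + item   # (new_state, output)

source = NativeCheckpointableIterator([1, 2, 3, 4])
print(list(RecurrentIterator(source, step_function=running_sum, initial_state=0)))   # [1, 3, 6, 10]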
+
+

Ancestors

+ +

Inherited members

+ +
+
+class PrefetchIterator +(source_iterator: CheckpointableIterator, buffer_size: int = 1000) +
+
+

An iterator prefetching data into a buffer on a separate thread to smooth out IO latency.

+

Args

+
+
source_iterator
+
checkpointable iterator to recur over
+
buffer_size
+
size of the queue between the threads
+
+
class PrefetchIterator(CheckpointableIterator):
+    """
+    An iterator prefetching data into a buffer on a separate thread to smooth out IO latency.
+
+    Args:
+        source_iterator: checkpointable iterator to recur over
+        buffer_size: size of the queue between the threads
+    """
+    def __init__(self, source_iterator: CheckpointableIterator, buffer_size: int=1000):
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        self._source_iterator = source_iterator  # type:CheckpointableIterator
+        self._buffer_size = buffer_size          # type: int
+        self._queue = None                       # type: Optional[ClosableQueue]
+        self._thread = None                      # type: Optional[Thread]
+        self.setstate(None)
+        
+    def getstate(self) -> Dict:
+        return {'source_state': self._source_state,
+                'item_offset' : self._item_offset  }
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        if self._thread is not None:  # if there is a prefetching thread running, close the queue and wait for the thread to terminate
+            assert self._queue is not None
+            self._queue.close()
+            self._thread.join()
+        
+        self._source_state = checkpoint['source_state'] if checkpoint is not None else None
+        self._item_offset  = checkpoint['item_offset' ] if checkpoint is not None else 0
+
+        self._source_iterator.setstate(self._source_state)
+
+        self._queue = ClosableQueue(maxsize=self._buffer_size)  # clear queue
+        # make thread daemonic so it is killed when the main program terminates
+        self._thread = Thread(target=self._prefetch_thread_fn, args=(self._source_iterator, self._item_offset, self._buffer_size, self._queue), daemon=True)
+        self._thread.start()
+
+    @staticmethod
+    def _prefetch_thread_fn(source, item_offset, buffer_size, queue):  # behavior of the prefetching thread, only call from that thread!
+        _advance_iterator(source, item_offset)  # skip to checkpoint
+
+        while True:
+            try:
+                item = next(source)
+            except StopIteration:
+                queue.close()
+                return
+            
+            if item_offset == buffer_size - 1:  # send a new source state at the END of each window of length _buffer_size
+                source_state = source.getstate()  # this is the state for retrieving the NEXT element, i.e. the first element of the next buffer
+                item_offset = 0
+            else:
+                source_state = None
+                item_offset += 1
+            msg = (item, source_state)
+
+            try:
+                queue.put(msg)
+            except ClosedException:
+                return
+
+    def __next__(self):
+        try:
+            msg = self._queue.get()
+        except ClosedException:
+            raise StopIteration
+
+        item, prefetch_source_state = msg
+        if prefetch_source_state is not None:
+            assert self._item_offset == self._buffer_size - 1  # we expect a new source state at the END of each window of length _buffer_size
+            self._source_state = prefetch_source_state
+            self._item_offset = 0
+        else:
+            self._item_offset = self._item_offset + 1
+            assert self._item_offset < self._buffer_size
+        return item  # for debugging, it's useful to return msg instead of item
+
+    def __del__(self):  # note: this is often not called. If you really need it, gc.collect() will do the trick.
+        if self._thread is not None:
+            assert self._queue is not None
+            self._queue.close()
+            try:
+                self._thread.join()
+            except:
+                pass
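A sketch of wrapping a pipeline stage so that upcoming items are produced on a background thread (toy data; in practice the source would be an expensive reader):

from infinibatch.iterators import NativeCheckpointableIterator, PrefetchIterator

source = NativeCheckpointableIterator(list(range(100)))
prefetched = PrefetchIterator(source, buffer_size=16)
print([next(prefetched) for _ in range(5)])   # items arrive via the background thread
checkpoint = prefetched.getstate()            # source state is tracked at buffer-window granularity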
+
+

Ancestors

+ +

Inherited members

+ +
+
+class BucketedReadaheadBatchIterator +(source_iterator: CheckpointableIterator, read_ahead: int, key: Callable[[Any], Any], batch_size: Union[int, Callable[[Any], int]], shuffle: bool = True, seed: Union[int, NoneType] = None) +
+
+

Iterates over items from a checkpointable iterator and groups items of similar length into batches.

+

The algorithm reads ahead a certain number of lines (e.g. 10 million), sorts them by +length, and then groups them into batches from start to end. The sort is stable, such +that prior randomization is not undone (except for the length grouping). The batch size +is dynamic, and determined by a user-provided callback.

+

This is based on Marian NMT's BatchGenerator.

+

Args

+
+
source_iterator
+
The data set that is read from. Typically this is an infinite source.
+
read_ahead
+
Number of items to fetch ahead for grouping purposes.
+
key
+
User-provided callback to define how data is sorted for purpose of batching.
+
batch_size
+
Batch size in number of items. Either an integer or a callback to determine batch size for a given first batch item.
+
shuffle
+
Pass False to not randomize the batches. (default: True)
+
seed
+
Random seed for batch shuffling.
+
+
class BucketedReadaheadBatchIterator(CheckpointableIterator):
+    """
+    Iterates over items from a checkpointable iterator and groups items of similar length into batches.
+
+    The algorithm reads ahead a certain number of lines (e.g. 10 million), sorts them by
+    length, and then groups them into batches from start to end. The sort is stable, such
+    that prior randomization is not undone (except for the length grouping). The batch size
+    is dynamic, and determined by a user-provided callback.
+
+    This is based on Marian NMT's BatchGenerator.
+    """
+
+    def __init__(self, source_iterator: CheckpointableIterator, read_ahead: int, key: Callable[[Any], Any], batch_size: Union[int,Callable[[Any], int]], shuffle: bool=True, seed: Optional[int]=None):
+        """
+        Args:
+            source_iterator: The data set that is read from. Typically this is an infinite source.
+            read_ahead: Number of items to fetch ahead for grouping purposes.
+            key: User-provided callback to define how data is sorted for purpose of batching.
+            batch_size: Batch size in number of items. Either an integer or a callback to determine batch size for a given first batch item.
+            shuffle: Pass False to not randomize the batches. (default: True)
+            seed: Random seed for batch shuffling.
+        """
+        if not isinstance(source_iterator, CheckpointableIterator):
+            raise ValueError('source_iterator has to be a CheckpointableIterator')
+        # keep arguments
+        self._key = key                # type: Callable[[Any], Any]
+        self._batch_size = batch_size  # type: Union[int,Callable[[Any], int]]
+        self._read_ahead = read_ahead  # type: int
+        # initialize state
+        self._random = None
+        if shuffle:
+            self._random = Random()                    # type: Random
+            if seed is not None:
+                self._random.seed(seed)
+        self._source_iterator = iter(source_iterator)  # type: CheckpointableIterator
+        self.setstate(None)
+
+    def getstate(self):
+        return {'source_state': self._source_state,
+                'random_state': self._random_state,
+                'num_served':   self._num_batches_yielded}
+
+    def setstate(self, checkpoint: Optional[Dict]):
+        self._source_state        = checkpoint['source_state'] if checkpoint else None  # type: Dict  -- state of input before reading the current set of batches
+        self._random_state        = checkpoint['random_state'] if checkpoint else None  # type: Any   -- state of random generator at _source_state
+        self._num_batches_yielded = checkpoint['num_served']   if checkpoint else 0     # type: int   -- number of batches served from the current set of batches
+        # checkpointing: restore to start of current set of batches
+        self._source_iterator.setstate(self._source_state)
+        if self._random_state:
+            self._random.setstate(self._random_state)
+        self._source_exhausted = False  # type: bool  -- set to True once we hit StopIteration on source
+        def _generate():
+            skip_to_checkpoint = self._num_batches_yielded
+            source_exhausted = False
+            while not source_exhausted:
+                # prefetch the readahead buffer
+                self._source_state = self._source_iterator.getstate()
+                self._random_state = self._random.getstate() if self._random else None
+                items = list(islice(self._source_iterator, self._read_ahead))
+                source_exhausted = (len(items) < self._read_ahead)
+                # create batches
+                batches = self._create_batches(items)
+                # shuffle the batches
+                if self._random:
+                    self._random.shuffle(batches)
+                # on first loop iteration, restore iterator inside batches from checkpoint
+                batches = iter(batches)
+                self._num_batches_yielded = _advance_iterator(batches, skip_to_checkpoint)
+                skip_to_checkpoint = 0
+                # main loop over batches in current read-ahead section
+                for batch in batches:
+                    self._num_batches_yielded += 1
+                    yield batch
+        self._iterator = _generate()  # type: Iterator  -- iterator into current set of batches
+
+    def _create_batches(self, items: List[Any]) -> List[List[Any]]:  # helper to form batches from a list of items
+            # sort by length, longest first
+            items.sort(key=self._key, reverse=True)  # note: sort() is stable, so we won't undo any randomization besides the bucketing
+            # group into batches
+            cur_batch = None
+            batches = []
+            for item in items:
+                if not cur_batch:
+                    batch_size = self._batch_size if isinstance(self._batch_size, int) else \
+                                 self._batch_size(item)
+                    cur_batch = []
+                cur_batch.append(item)
+                if len(cur_batch) >= batch_size:  # this batch is full
+                    batches.append(cur_batch)
+                    cur_batch = None
+            if cur_batch:
+                batches.append(cur_batch)
+            return batches
+
+    def __next__(self):
+        return next(self._iterator)
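A sketch on toy token sequences with a dynamic batch size derived from the longest sequence in each batch; the sizing rule here is hypothetical:

from infinibatch.iterators import NativeCheckpointableIterator, BucketedReadaheadBatchIterator

sentences = [["tok"] * n for n in (5, 2, 7, 3, 9, 1)]        # toy token sequences
batches = BucketedReadaheadBatchIterator(
    NativeCheckpointableIterator(sentences),
    read_ahead=6,                                            # sort/bucket within windows of 6 items
    key=lambda sent: len(sent),                              # sort by length, longest first
    batch_size=lambda longest: max(1, 12 // len(longest)),   # fewer items when sequences are long
    seed=1)
for batch in batches:
    print([len(sent) for sent in batch])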
+
+

Ancestors

+ +

Inherited members

+ +
+
+
+
+ +
+ + + + + \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/torch/data.html b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/torch/data.html new file mode 100644 index 0000000000000000000000000000000000000000..084b0ac93b596a09a349bb2aaa4509fd5b5563ad --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/torch/data.html @@ -0,0 +1,268 @@ + + + + + + +infinibatch.torch.data API documentation + + + + + + + + + +
+
+
+

Module infinibatch.torch.data

+
+
+
import torch
+from infinibatch.iterators import CheckpointableIterator
+from infinibatch.datasets  import chunked_dataset_iterator
+from typing import Union, Iterable, Any
+
+
+# @TODO: This has been tested once, but we have no regression test presently. I am worried tests will fail if Torch is not installed.
+class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
+    """
+    Wraps a CheckpointableIterator into a PyTorch IterableDataset, which is recognized by its type by
+    PyTorch's DataLoader class.
+    """
+    def __init__(self, source: CheckpointableIterator):
+        super().__init__()
+        self._source = source
+
+    def __iter__(self):  # this is called in the forked clone
+        worker_info = torch.utils.data.get_worker_info()
+        assert worker_info is None or worker_info.num_workers == 1  # not supported since we can't get at the checkpoint for each worker
+        return iter(self._source)
+
+
+# @TODO: This is currently untested, and may not work presently.
+class IterableChunkedDataset(torch.utils.data.IterableDataset):
+    def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool=True, buffer_size: int=2**20, transform=None, seed: int=None, world_size: int=1, rank: int=0, num_workers_per_rank: int=1):
+        super().__init__()
+        self.rank = rank
+        self.num_workers_per_rank = num_workers_per_rank
+        # instance_rank is set assuming that num_workers_per_rank = 1 and adapted dynamically in __iter__
+        self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size, transform=transform, seed=seed, num_instances=world_size*num_workers_per_rank, instance_rank=rank)
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:  # single-process data loading
+            self.dataset._instance_rank = self.rank
+        else:
+            assert worker_info.num_workers == self.num_workers_per_rank
+            self.dataset._instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
+        return iter(self.dataset)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class IterableCheckpointedDataset +(source: CheckpointableIterator) +
+
+

Wraps a CheckpointableIterator into a PyTorch IterableDataset, which is recognized by its type by +PyTorch's DataLoader class.

+
class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
+    """
+    Wraps a CheckpointableIterator into a PyTorch IterableDataset, which is recognized by its type by
+    PyTorch's DataLoader class.
+    """
+    def __init__(self, source: CheckpointableIterator):
+        super().__init__()
+        self._source = source
+
+    def __iter__(self):  # this is called in the forked clone
+        worker_info = torch.utils.data.get_worker_info()
+        assert worker_info is None or worker_info.num_workers == 1  # not supported since we can't get at the checkpoint for each worker
+        return iter(self._source)
+
+

Ancestors

+
    +
  • torch.utils.data.dataset.IterableDataset
  • +
  • torch.utils.data.dataset.Dataset
  • +
+
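The generated page above only documents the wrapper itself. As a hedged usage sketch (not part of the original docs), it can be combined with `chunked_dataset_iterator` from the package tutorial; the `corpus_chunks` paths are assumed to exist and `infinibatch` is assumed to be importable:

```python
# Sketch only: wrap an infinibatch pipeline for use with torch.utils.data.DataLoader.
# The chunk files and read function follow the corpus_chunks example from the tutorial.
import glob, gzip
import torch
from infinibatch.datasets import chunked_dataset_iterator
from infinibatch.torch.data import IterableCheckpointedDataset

source = chunked_dataset_iterator(
    chunk_refs=glob.glob("corpus_chunks/corpus.*.txt.gz"),
    read_chunk_fn=lambda path: iter(
        gzip.decompress(open(path, "rb").read()).decode("utf-8").splitlines()),
    buffer_size=6, seed=1)

ds = IterableCheckpointedDataset(source)
# Keep num_workers at 0 (or 1): __iter__ above asserts this, because worker
# processes cannot report their checkpoints back to the main process.
loader = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=2)
print(next(iter(loader)))

# Checkpointing is done on the wrapped infinibatch iterator, not on the DataLoader:
checkpoint = source.getstate()   # save this alongside the model checkpoint
source.setstate(checkpoint)      # later: resume from exactly this position
```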
+
+class IterableChunkedDataset +(paths: Union[str, Iterable[str]], shuffle: bool = True, buffer_size: int = 1048576, transform=None, seed: int = None, world_size: int = 1, rank: int = 0, num_workers_per_rank: int = 1) +
+
+

An iterable Dataset.

+

All datasets that represent an iterable of data samples should subclass it. +Such form of datasets is particularly useful when data come from a stream.

+

All subclasses should overwrite :meth:__iter__, which would return an +iterator of samples in this dataset.

+

When a subclass is used with :class:~torch.utils.data.DataLoader, each +item in the dataset will be yielded from the :class:~torch.utils.data.DataLoader +iterator. When :attr:num_workers > 0, each worker process will have a +different copy of the dataset object, so it is often desired to configure +each copy independently to avoid having duplicate data returned from the +workers. :func:~torch.utils.data.get_worker_info, when called in a worker +process, returns information about the worker. It can be used in either the +dataset's :meth:__iter__ method or the :class:~torch.utils.data.DataLoader 's +:attr:worker_init_fn option to modify each copy's behavior.

+

Example 1: splitting workload across all workers in :meth:__iter__::

+
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
+...     def __init__(self, start, end):
+...         super(MyIterableDataset).__init__()
+...         assert end > start, "this example code only works with end > start"
+...         self.start = start
+...         self.end = end
+...
+...     def __iter__(self):
+...         worker_info = torch.utils.data.get_worker_info()
+...         if worker_info is None:  # single-process data loading, return the full iterator
+...             iter_start = self.start
+...             iter_end = self.end
+...         else:  # in a worker process
+...             # split workload
+...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
+...             worker_id = worker_info.id
+...             iter_start = self.start + worker_id * per_worker
+...             iter_end = min(iter_start + per_worker, self.end)
+...         return iter(range(iter_start, iter_end))
+...
+>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
+>>> ds = MyIterableDataset(start=3, end=7)
+
+>>> # Single-process loading
+>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
+[3, 4, 5, 6]
+
+>>> # Multi-process loading with two worker processes
+>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
+>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
+[3, 5, 4, 6]
+
+>>> # With even more workers
+>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
+[3, 4, 5, 6]
+
+

Example 2: splitting workload across all workers using :attr:worker_init_fn::

+
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
+...     def __init__(self, start, end):
+...         super(MyIterableDataset).__init__()
+...         assert end > start, "this example code only works with end > start"
+...         self.start = start
+...         self.end = end
+...
+...     def __iter__(self):
+...         return iter(range(self.start, self.end))
+...
+>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
+>>> ds = MyIterableDataset(start=3, end=7)
+
+>>> # Single-process loading
+>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
+[3, 4, 5, 6]
+>>>
+>>> # Directly doing multi-process loading yields duplicate data
+>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
+[3, 3, 4, 4, 5, 5, 6, 6]
+
+>>> # Define a `worker_init_fn` that configures each dataset copy differently
+>>> def worker_init_fn(worker_id):
+...     worker_info = torch.utils.data.get_worker_info()
+...     dataset = worker_info.dataset  # the dataset copy in this worker process
+...     overall_start = dataset.start
+...     overall_end = dataset.end
+...     # configure the dataset to only process the split workload
+...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
+...     worker_id = worker_info.id
+...     dataset.start = overall_start + worker_id * per_worker
+...     dataset.end = min(dataset.start + per_worker, overall_end)
+...
+
+>>> # Multi-process loading with the custom `worker_init_fn`
+>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
+>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
+[3, 5, 4, 6]
+
+>>> # With even more workers
+>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
+[3, 4, 5, 6]
+
+
+ +Expand source code + +
class IterableChunkedDataset(torch.utils.data.IterableDataset):
+    def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool=True, buffer_size: int=2**20, transform=None, seed: int=None, world_size: int=1, rank: int=0, num_workers_per_rank: int=1):
+        super().__init__()
+        self.rank = rank
+        self.num_workers_per_rank = num_workers_per_rank
+        # instance_rank is set assuming that num_workers_per_rank = 1 and adapted dynamically in __iter__
+        self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size, transform=transform, seed=seed, num_instances=world_size*num_workers_per_rank, instance_rank=rank)
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:  # single-process data loading
+            self.dataset._instance_rank = self.rank
+        else:
+            assert worker_info.num_workers == self.num_workers_per_rank
+            self.dataset._instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
+        return iter(self.dataset)
+
+

Ancestors

+
    +
  • torch.utils.data.dataset.IterableDataset
  • +
  • torch.utils.data.dataset.Dataset
  • +
+
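Because the wrapper above is flagged as untested in its source, the same per-rank sharding can also be obtained directly from `chunked_dataset_iterator`, whose `num_instances` and `instance_rank` arguments it wraps. A minimal sketch, assuming the tutorial's `corpus_chunks` files and placeholder rank values:

```python
# Sketch: shard the chunk stream across ranks without the torch wrapper.
import glob, gzip
from infinibatch.datasets import chunked_dataset_iterator

def read_chunk(path):
    # one text line per data item, as in the corpus_chunks tutorial
    return iter(gzip.decompress(open(path, "rb").read()).decode("utf-8").splitlines())

world_size, rank = 2, 0   # placeholders; normally taken from the distributed launcher
it = chunked_dataset_iterator(
    chunk_refs=glob.glob("corpus_chunks/corpus.*.txt.gz"),
    read_chunk_fn=read_chunk,
    buffer_size=6,
    seed=1,
    num_instances=world_size,   # total number of training processes
    instance_rank=rank)         # which shard this process reads

print(next(it))   # each rank receives a different, shuffled subset of chunks
```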
+
+
+
+ +
+ + + + + \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/torch/index.html b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/torch/index.html new file mode 100644 index 0000000000000000000000000000000000000000..6468d9bc5da8da7fad63dee970ec8b1339134a10 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/docs/infinibatch/torch/index.html @@ -0,0 +1,65 @@ + + + + + + +infinibatch.torch API documentation + + + + + + + + + +
+
+
+

Module infinibatch.torch

+
+
+
+
+

Sub-modules

+
+
infinibatch.torch.data
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + + + \ No newline at end of file diff --git a/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/__init__.py b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0539435729f8df6f6e98a3cd86d66627971ae58 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/__init__.py @@ -0,0 +1,293 @@ +""" +Infinibatch is a library of checkpointable iterators for randomized data loading of massive data sets in deep neural network training. + + +## Features + + * support for corpora much larger than fit into RAM + * hierarchical block+sentence-level randomization over the whole corpus, different randomization in each epoch + * only load the data that is needed + * very fast start-up time (does not need to read full corpus) + * only requires the most basic of data preparation (e.g. no indexing) + * for multi-GPU, only load what the respective GPU needs + * 100% accurate check-pointing, restore from checkpoint should not read all data up to the checkpoint + * support automatic bucketed batching with dynamic batch sizes + * pre-fetching thread + * composable, as to support for complex batching, e.g. negative samples from multiple documents + + +## Getting Started + +Infinibatch requires Python 3.5 and has no dependencies. +There is presently no pip package. +To install it, please copy this library into a subfolder in your project: +```bash +cd YOUR_PROJECT_FOLDER +git clone https://msasg.visualstudio.com/DefaultCollection/SDRG/_git/infinibatch +``` +or, better, as a submodule reference: +```bash +git submodule add https://msasg.visualstudio.com/DefaultCollection/SDRG/_git/infinibatch +``` +It is now located at `infinibatch/infinibatch`, e.g. the main import file is `infinibatch/infinibatch/__init__.py`. + +To import it, you need to add that folder to your `PYTHONPATH` variable externally, or to `sys.path` inside the code: +```python +import sys +sys.path.insert(0,'infinibatch') # note: relative paths are relative to your current dir, not to the python script +import infinibatch +``` + +## Tutorial + +This little tutorial walks you through the steps of preparing your data and consuming them from Python code as batches. + +### Infinibatch Basics: Iterators and Checkpointing + +Infinibatch provides [Python iterators](https://docs.python.org/3.5/glossary.html#term-iterator) +to read your data. +An iterator represents a stream of data that can be retrieved item by item, e.g. via a +`for` loop or repeatedly calling `next()` on it. + +Infinibatch is agnostic to the data type of the items, which is determined by a user-supplied file-read function. +In NLP applications, items would typically be tuples of text. In other applications, +they can be images or an audio file with a textual annotation. + +Infinibatch makes it easy to read your data in randomized order, and supports checkpointing, which allows you to restart training exactly where you left off. + +Randomization is done _on the fly_, which means that it is not necessary to read the entire data set into memory +to be shuffled. Infinibatch implements a hierarchical shuffling algorithm +that only holds a subset of the data in RAM at any point in time. + +Infinibatch iterators are _checkpointable_. +Checkpointing lets you retrieve the current position (the "checkpoint") in the data stream at any time, so that +later, you can "rewind" to that same position. +The sad reality is that long-running trainings occasionally crash. 
+To be able to continue a crashed training as if it had not crashed, +save your Infinibatch iterator's checkpoint to disk whenever you save an intermediate model during training. +To restart a crashed training, reset the iterator to the saved checkpoint. +The data reader will now yield the exact same data-item sequence it would have yielded without the crash. + +### Data Preparation + +Infinibatch has one requirement on your data organization: +To use your data with Infinibatch, it must be split into a large number of small chunks. +A chunk is the smallest unit of data that is loaded from disk into RAM. Infinibatch holds a random subset of chunks in memory +that it randomly draws samples from. + +Below we want to show how such a split can be created. An easy way to split your data into chunks is with the Linux `split` command. + +In this tutorial, our "corpus" consists of 6 lines of text, where each line is one data item. +To create that corpus, please run this command in a bash shell. It create s a 6-line text file named `corpus.txt`: +```bash +echo \\ +'Lorem ipsum dolor sit amet, +consectetur adipiscing elit, +sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +The quick brown fox jumps over the lazy dog.' \\ +> corpus.txt +``` +Now let us split it into 3 chunks of 2 lines each. Each chunk is stored as a zipped text file. +We will create them inside a new subdirectory called `corpus_chunks`: +```bash +mkdir corpus_chunks +split --lines 2 --numeric-suffixes \\ + --filter 'gzip > corpus_chunks/$FILE.txt.gz' \\ + corpus.txt corpus. +``` +This will have created three files: `corpus_chunks/corpus.00.txt.gz`, `corpus_chunks/corpus.01.txt.gz`, and `corpus_chunks/corpus.02.txt.gz`. +To verify whether the data has been split as expected, you can use this command: +```bash +zcat corpus_chunks/corpus.*.txt.gz +``` + +Hint: For large corpora, we recommend replacing `gzip` by `pigz` (`apt-get install pigz`), which runs notably faster via multi-threading. + +### Reading Items in Random Order With Infinibatch + +We will first show the easiest way to read data with Infinibatch, using the helper function `chunked_dataset_iterator``()`. +This function will create an Infinibatch iterator that yields the content of your data in random order. +Please the following program: +```python +import sys, gzip, glob +sys.path.insert(0,'infinibatch') +from infinibatch import datasets as ds + +ds = ds.chunked_dataset_iterator( + chunk_refs = glob.glob('corpus_chunks/corpus.*.txt.gz'), + read_chunk_fn = lambda path: iter(gzip.decompress(open(path, "rb") \\ + .read()).decode(encoding='utf-8') \\ + .splitlines()), + buffer_size = 6, seed = 1) + +for i in range(10): + print(next(ds)) +``` +You should get output that contains the 6 example lines in randomized order: +```text +Lorem ipsum dolor sit amet, +consectetur adipiscing elit, +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +The quick brown fox jumps over the lazy dog. +sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +consectetur adipiscing elit, +Lorem ipsum dolor sit amet, +The quick brown fox jumps over the lazy dog. 
+sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +``` +Note: The `buffer_size` parameter determines how many sentences are read into memory at any given time, +to draw randomized items from. In real settings with corpora of hundreds of millions of text lines, +the `buffer_size` parameter should be set in the millions. +RAM usage and startup time will be proportional to the buffer size +(but much lower than having to load the entire corpus into RAM). + +### Reading Items of Different Lengths in Batches + +For deep learning, we want to group multiple items into batches. +For NLP tasks, items are often lines of text of varying length. +Infinibatch implements an algorithm that randomizes the input sequence and groups it into +batches of approximately the same length (aka _bucketing_). + +Infinibatch's `BucketedReadaheadBatchIterator` performs this task. +It implements an algorithm modeled after the [Marian toolkit](https://github.com/marian-nmt/marian) +that preloads a large number of randomized items (typically millions; in this example: 6), +sorts them and groups them into batches of similar length, and then yields +them, in turn, in randomized order. + +Here is an example. Note that the `BucketedReadaheadBatchIterator` accepts +the previous randomized sentence sequence iterator (`ds`) as the source of items to randomize over. +This is an example how one forms pipelines of iterators with Infinibatch +(a concept familiar from Python's o wn `itertools`). +Once an iterator is passed to another as its source, consider it owned by that other iterator, +it must no longer be accessed by the calling code. +```python +import sys, gzip, glob +sys.path.insert(0,'infinibatch') +from infinibatch import datasets as ds +from infinibatch import iterators as it + +ds = ds.chunked_dataset_iterator( + chunk_refs = glob.glob('corpus_chunks/corpus.*.txt.gz'), + read_chunk_fn = lambda path: iter(gzip.decompress(open(path, "rb") \\ + .read()).decode(encoding='utf-8') \\ + .splitlines()), + buffer_size = 6, seed = 1) + +bs = it.BucketedReadaheadBatchIterator( + source_iterator = ds, # note: this is the iterator from above + read_ahead = 6, + key = lambda line: len(line), + batch_size = 2, + seed = 1) + +for i in range(25): + print(next(bs)) +``` +This code should output something like this: +```python +['sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.', + 'The quick brown fox jumps over the lazy dog.'] +['consectetur adipiscing elit,', 'Lorem ipsum dolor sit amet,'] +['Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.', + 'Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.'] +``` +followed by different permutations of the same tuples. +As you can see, the sentences are in random order and grouped in batches of 2 of approximately the same length. +You may notice that there is no variation in how the items get grouped into batches--that +is an artifact of this example, and generally not the case in real use when the data size is much larger +than the batch size. + +In NLP, sentence length often varies considerably. As a result, using batches of a fixed number of lines, +as in the example above, will waste GPU RAM and cores. +This is because the number of lines is limited by the longest possible sequence; batches of shorter lines +would leave GPU cycles on the table. 
+Ideally, one would use batches that have as many lines as fit into GPU RAM, +given the number of tokens of the longest line in the batch. +To support variable batch sizes, Infinibatch allows to pass a function as the `batch_size` parameter. +That function will be given the longest item of a batch and should estimate how many items of at most this length can fit. + +In our example, we assume that batches can hold at most 150 tokens. +Please change the above code as follows: +```python + batch_size = lambda longest_line: 150 // len(longest_line), +``` +The output looks like this: +``` +['consectetur adipiscing elit,', 'Lorem ipsum dolor sit amet,'] +['Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.'] +['sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.', + 'The quick brown fox jumps over the lazy dog.'] +['Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.'] +``` +That shorter sentences got grouped, while longer did not because they would exceed the total of 150 characters. + +### Reading Batches Into Numpy Arrays + +Lastly, we will need to feed batches into our favorite deep-learning tool. +We will show how to convert the batches of text lines into padded `numpy` arrays. + +In a typical NLP application, text items would be tokenized, and then each token +would be represented by an index into a unit vocabulary. +For simplicity, in this example each character is its own token, +and each token's numeric unit index is just its ASCII code. +These sequences are then padded to equal length with -1, and converted into a `numpy` array. + +Please rerun the previous example, but first insert the following code before the final `for` loop. +This example uses an Infinibatch `MapIterator`, which applies a user-supplied function or +lambda to each item: +```python +import numpy as np +def collate(lines_batch): + # tokenize all lines in the batch and map to unit ids + ids_batch = [[ord(c) for c in line] for line in lines_batch] + # create a padded numpy array as wide as the longest line, + # where shorter sequences are padded with -1 + width = max(len(ids) for ids in ids_batch) + return np.array([ids + [-1] * (width-len(ids)) for ids in ids_batch]) + +bs = it.MapIterator( + source_iterator = bs, + transform = collate) +``` +This will output batches like this. Note that in batches with multiple sentences, +some entries are padded with `-1`. 
+```python +[[ 99 111 110 115 101 99 116 101 116 117 114 32 97 100 105 112 105 115 + 99 105 110 103 32 101 108 105 116 44] + [ 76 111 114 101 109 32 105 112 115 117 109 32 100 111 108 111 114 32 + 115 105 116 32 97 109 101 116 44 -1]] +[[ 85 116 32 101 110 105 109 32 97 100 32 109 105 110 105 109 32 118 + 101 110 105 97 109 44 32 113 117 105 115 32 110 111 115 116 114 117 + 100 32 101 120 101 114 99 105 116 97 116 105 111 110 32 117 108 108 + 97 109 99 111 32 108 97 98 111 114 105 115 32 110 105 115 105 32 + 117 116 32 97 108 105 113 117 105 112 32 101 120 32 101 97 32 99 + 111 109 109 111 100 111 32 99 111 110 115 101 113 117 97 116 46]] +[[115 101 100 32 100 111 32 101 105 117 115 109 111 100 32 116 101 109 + 112 111 114 32 105 110 99 105 100 105 100 117 110 116 32 117 116 32 + 108 97 98 111 114 101 32 101 116 32 100 111 108 111 114 101 32 109 + 97 103 110 97 32 97 108 105 113 117 97 46] + [ 84 104 101 32 113 117 105 99 107 32 98 114 111 119 110 32 102 111 + 120 32 106 117 109 112 115 32 111 118 101 114 32 116 104 101 32 108 + 97 122 121 32 100 111 103 46 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]] +[[ 68 117 105 115 32 97 117 116 101 32 105 114 117 114 101 32 100 111 + 108 111 114 32 105 110 32 114 101 112 114 101 104 101 110 100 101 114 + 105 116 32 105 110 32 118 111 108 117 112 116 97 116 101 32 118 101 + 108 105 116 32 101 115 115 101 32 99 105 108 108 117 109 32 100 111 + 108 111 114 101 32 101 117 32 102 117 103 105 97 116 32 110 117 108 + 108 97 32 112 97 114 105 97 116 117 114 46]] +``` + +## Where To Go From Here + +The above tutorial showed you the use of the most common iterator type, as created by the +convenience function `chunked_dataset_iterator()`. + +Not all real-life scenarios are covered by this function. For example, multi-task learning +scenarios require more complex combinations of data. To create those, you will need +to compose the necessary data reader from the underlying building blocks. +This is described at the documentation of the module `iterators`. +""" diff --git a/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/closablequeue.py b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/closablequeue.py new file mode 100644 index 0000000000000000000000000000000000000000..08a2a29690f9ebacae8576f78edd4a9132413ad1 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/closablequeue.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections import deque +from threading import Condition, Lock, Thread + + +class ClosedException(Exception): + pass + + +class ClosableQueue: + """ + A thread-safe queue that can be closed + + As long as the the queue is not closed, it behaves just like a thread-safe queue with a capacity limit: + - put blocks until the item can be added + - get blocks until there is an item to be returned + + Once the queue is closed, no more items can be added but existing items can be removed: + - put always raises a ClosedException + - get returns an item if the queue is not empty and otherwise raises a ClosedException + """ + + def __init__(self, maxsize: int = 1000): + self._maxsize = maxsize + self._queue = deque() + self._mutex = Lock() + self._not_empty = Condition(self._mutex) + self._not_full = Condition(self._mutex) + self._closed = False + + def put(self, item): + with self._not_full: + if self._closed: + raise ClosedException( + "This queue has been closed, no more items can be added." 
+ ) + while len(self._queue) >= self._maxsize: + self._not_full.wait() + if self._closed: + raise ClosedException( + "This queue has been closed, no more items can be added." + ) + self._queue.append(item) + self._not_empty.notify() + + def get(self): + with self._not_empty: + if self._closed and len(self._queue) == 0: + raise ClosedException( + "This queue has been closed and is empty, no more items can be retrieved." + ) + while len(self._queue) == 0: + self._not_empty.wait() + if self._closed and len(self._queue) == 0: + raise ClosedException( + "This queue has been closed and is empty, no more items can be retrieved." + ) + item = self._queue.popleft() + self._not_full.notify() + return item + + def close(self): + with self._mutex: + self._closed = True + self._not_empty.notify_all() + self._not_full.notify_all() diff --git a/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/datasets.py b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4191703a51b56f9e6b512df78ab838015a8257 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/datasets.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .iterators import ( + create_source_iterator, + SelectManyIterator, + PrefetchIterator, + BufferedShuffleIterator, + BlockwiseShuffleIterator, + MapIterator, +) +from typing import List, Union, Iterable, Iterator, Callable, Any, Optional, Dict +import os, sys + +""" +This module contains common datasets, which are implemented as convenience functions that compose underlying Infinibatch iterators. +""" + + +def bump_seed(seed: Optional[int], step=1): + """ + Helper to bump a random seed if not None. + """ + return None if seed is None else seed + 1 + + +def chunked_dataset_iterator( + chunk_refs: List, + read_chunk_fn: Callable[[Any], Iterator], + buffer_size: int, + train: bool = True, + seed: Optional[int] = None, + shuffle: bool = True, + use_windowed: bool = False, + transform: Callable[[Any], Any] = None, + prefetch: bool = True, + num_instances: int = 1, + instance_rank: int = 0, +): + """ + Dataset reading data from gzipped chunks. + + If train=True, this chunks are strided assigned to instances in strides and the data is infinitely repeated in permutations. + Otherwise, the chunks are split among the instances in consecutive blocks and the data is not repeated. + This way, when using this dataset for inference on multiple GPUs, to order the outputs in a way that corresponds + to the original order of the data items in the dataset, one simply has to collect the lists of outputs from each GPU + and then concatenate these lists in order of increasing rank. + When using MPI, this can be achieved by a gather-operation to get a list of lists of outputs, one list per GPU, + followed by flattening the lists back into a single list. + + Args: + chunk_refs: references (such as path names) to chunk files + read_chunk_fn: function(chunk_ref) -> Iterator to read a chunk's content into an iterator over its items, e.g. read a file and split into text lines + train: see above + shuffle: if true, the data is shuffled. If train is False then shuffle must be False as well. 
+ buffer_size: size of the buffer in number of samples / data items used for shuffling (defaul t: 2**20) + transform: transform to be applied to each data item (transform(Any) -> Any) + prefetch: if True, insert a prefetch iterator with buffer_size + seed: random seed (or None) + num_instances: number of instances of this dataset. Meant for use with multi-process data loading, e.g., in distributed training. + instance_rank: rank of this instance of the dataset. Meant for use with multi-process data loading, e.g., in distributed training. + use_windowed: temporary option to switch back to the WindowedShuffleIterator (default False). Will go away once shown that we don't need it anymore. + """ + if not train and shuffle: + raise ValueError("shuffling is not supported when train=False") + # set up the chunk reader + chunk_refs = create_source_iterator( + chunk_refs, + train=train, + seed=seed, + shuffle=shuffle, + num_instances=num_instances, + instance_rank=instance_rank, + ) + # set up the item reader + samples = SelectManyIterator( + source_iterator=chunk_refs, collection_selector=read_chunk_fn + ) + # wrap the I/O operation in a prefetch iterator + if prefetch: + samples = PrefetchIterator(samples, buffer_size) + # set up the item randomizer + if shuffle: + if use_windowed: + samples = BufferedShuffleIterator(samples, buffer_size, bump_seed(seed, 1)) + else: + samples = BlockwiseShuffleIterator(samples, buffer_size, bump_seed(seed, 1)) + # apply transform, if given + if transform is not None: + samples = MapIterator(samples, transform) + # this is what we are serving out + return samples diff --git a/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/iterators.py b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/iterators.py new file mode 100644 index 0000000000000000000000000000000000000000..a3be2e238ef4d561a63005ea6b18fc83001fc214 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/iterators.py @@ -0,0 +1,1217 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +## Overview + +This part of the documentation covers the __advanced usage__ of Infinibatch by assembling __custom data loading pipelines__. +Before you continue, please go through the tutorial on the top-level of the documentation of the `infinibatch` module. + +Two of the main features of Infinibatch are __lazy evaluation__ through the use of __iterators__ +and built-in support for __checkpointing__. +In this section, we give an introduction to these features and the basic usage of the Infinibatch iterator library. + + +### Iterators + +As a Python programmer, you are probably familiar with the concept of iterators. +According to the [Python documentation](https://docs.python.org/3.5/glossary.html#term-iterator), +an iterator is an object representing a stream of data, +and repeated calls to the iterator's `__next__()` method (or passing it to the built-in function `next()`) +return successive items in the stream. +It is important not to confuse an [iterator](https://docs.python.org/3.5/glossary.html#term-iterator) +with an [iterable](https://docs.python.org/3.5/glossary.html#term-iterable). +For more information on this subject, please follow the links above. + +The Python standard library contains a module of iterators called `itertools` +that bears some resembles to Infinibatch. +Infinibatch differs from `itertools` in two ways: + +1. 
Infinibatch provides iterators specifically for the purpose of creating __randomized batches of data for machine learning__. +2. All iterators in Infinibatch support __checkpointing__ (see the following section). + +Infinibatch iterators are not directly compatible with itertools due to the checkpointing requirement. + +Infinibatch enables you to build complex data loaders by combining iterators from this module into a pipeline. +To give you a high-level idea of how this is works, we provide a very simple example. +Note that this example is completely artificial and does not solve any useful task. +Its only purpose is to demonstrate the behavior of a pipeline of iterators. +We provide a more realistic example in a later section. + +First, we create a small test data set. +>>> dataset = list(range(6)) # 0, 1, 2, 3, 4, 5 + +We can turn this data set into an Infinibatch iterator by wrapping it in a `NativeCheckpointableIterator`. +>>> it = NativeCheckpointableIterator(dataset) # 0, 1, 2, 3, 4, 5 + +We can then transform the data items using a `MapIterator`, +which applies a given function to each individual data item. +For example, we can multiply each data item by 2. +>>> it = MapIterator(it, lambda n: 2 * n) # 0, 2, 4, 6, 8, 10 + +We can restructure the data set by batching together pairs of data items into lists using a `FixedBatchIterator`. +>>> it = FixedBatchIterator(it, batch_size=2) # [0, 2], [4, 6], [8, 10] + +Using another `MapIterator`, we can reduce each of these lists to its second element. +>>> it = MapIterator(it, lambda l: l[1]) # 2, 6, 10 + +Finally, we can use the resulting iterator `it` just like any standard Python iterator. +```py +>>> for item in it: +... print(item) +2 +6 +10 + +``` + +By using iterators, Infinibatch operates in a __lazy__ fashion: +It generally doesn't apply operations to an entire data set at once, +but rather operates on individual data items on-the-fly as they are consumed. +When used correctly, this allows Infinibatch to have a low start-up time and low memory overhead. +For more detail on this, please consult the section on performance considerations below. + + +### Checkpointing + +The main features that sets Infinibatch iterators apart from standard Python iterators is that they support __checkpointing__. +A checkpoint encapsulates the internal state of an entire pipeline of iterators at a specific point while iterating through a data set. +Once you retrieve a checkpoint, you can later use it to reset the pipeline of iterators to the exact state it was in +when the checkpoint was created. +Checkpoints can easily be serialized and stored to disk using [Pythons `pickle` module](https://docs.python.org/3.5/library/pickle.html). +Infinibatch's checkpointing feature is particularly useful when you're training large deep neural network models over days or weeks, +and you want to make sure that, in case your training is interrupted for any reason, __you can pick up your training exactly where you left off__. + +The checkpointing interface consists of two functions `getstate` and `setstate` that are defined in `CheckpointableIterator`, +the common base class of all iterators in this module. +As the names suggest `getstate` returns a checkpoint object that represents the state of a pipeline at the time the function is called, +and 'setstate' receives a checkpoint object to reset the state of a pipeline. +`setstate` also accepts `None`, which resets a pipeline to the __beginning__ of the iteration, +i.e. 
the state of the pipeline immediately after its construction. + +It is important to realize that __a checkpoint represents the state of a complete pipeline of iterators__. +If you have a pipeline consisting of a sequence of iterators, you only have to call `getstate` on the __last__ iterator in the sequence +to capture the state of the entire pipeline. +Internally, this is achieved by recursive calls that traverse the entire data loading pipeline to collect the state of every iterator in it. +Similarly, when you want to reset a pipeline to a previous state, you only have to call `setstate` on the __last__ iterator in the pipeline. + + +To demonstrate this, we recreate the pipeline from the previous section. +>>> dataset = list(range(6)) # 0, 1, 2, 3, 4, 5 +>>> it = NativeCheckpointableIterator(dataset) # 0, 1, 2, 3, 4, 5 +>>> it = MapIterator(it, lambda n: 2 * n) # 0, 2, 4, 6, 8, 10 +>>> it = FixedBatchIterator(it, batch_size=2) # [0, 2], [4, 6], [8, 10] +>>> it = MapIterator(it, lambda l: l[1]) # 2, 6, 10 + +Since `it` behaves just like a standard Python iterator, we can call `next` to retrieve its first element. +> >> next(it) +2 + +We can now call `getstate` on `it` (which is the last `MapIterator` in the pipeline) +to get a checkpoint of the internal state of the entire data loading pipeline. +>>> checkpoint = it.getstate() + +Note that the checkpoint represents the internal state of the pipeline after the data item `2` has been retrieved. +Using the checkpoint, we can always return to this __exact__ point in the data set. +To show this, let's exhaust the iterator by casting it to a list. +>>> list(it) +[6, 10] + +Since the iterator is now exhausted, calling `next` raises a `StopIteration` exception. +``` +>>> next(it) +Traceback (most recent call last): + ... +StopIteration + +``` + +We can now reset the pipeline to the checkpoint using `setstate`. +>>> it.setstate(checkpoint) + +This recovers the state of the pipeline after the data item `2` has been retrieved. +Thereby, we expect the next element to be `6`. +>>> next(it) +6 + + +## Types of Iterators + +This section provides a brief overview of the different types of iterators in Infinibatch. + + +### Classes and Factory Functions + +Most iterators in this module are implemented as classes that inherit from the abstract base class `CheckpointableIterator`. +However, some iterators (such as the `BlockwiseShuffleIterator`) are simple combinations of other iterators. +These iterators are implemented as __factory functions__ that construct a pipeline of iterators +and return the last iterator in the pipeline. +For consistency with class-based iterators, +we name these factory function using CamelCase instead of the more pythonic use_of_underscores. + +.. todo:: + We currently also have one factory function that actually looks like one: `create_source_iterator`. + Provide a comment on this describing why that is. + + +### Source Iterators + +There are three iterators that are intended to go at the __beginning__ of a data loading pipeline: + +- `InfinitePermutationSourceIterator`: +This iterator accepts a list, shuffles it, and yields its elements. +It repeats this infinitely, shuffling the list after each pass. +Thereby, __this iterator is infinte and cannot be exhausted__. +This iterator is meant to be used as the first iterator in a training scenario +and supports splitting the data for multi-GPU training. +- `ChunkedSourceIterator`: +This iterator accepts a list and yields its elements. 
+It is meant to be used as the first iterator in an inference or validation scenario +and supports splitting the data for mult-GPU inference. +- `NativeCheckpointableIterator`: +This iterator wraps a Python iterable and makes it checkpointable. +It is mainly intended for demonstration and debugging purposes. + + +### Shuffling + +.. todo:: Describe `BufferedShuffleIterator` and `BlockwiseShuffleIterator`. + + +### Batching, SelectMany, and Windowing + +.. todo:: Describe `FixedBatchIterator`, `SelectManyIterator`, and `WindowedIterator`. + + +### Mapping + +.. todo:: Describe `MapIterator`, `ParallelMapIterator`, `RecurrentIterator`, and `SamplingRandomMapIterator`. + + +### Other Iterators + +.. todo:: Describe `ZipIterator`, `PrefetchIterator`, and `BucketedReadaheadBatchIterator`. + + +## Complete Example + +.. todo:: + Give a more realistic example following, in broad strokes, the ChunkedDataset including: + + - use gzip chunks + - training pipeline example + - inference pipeline example + - pipeline that can do both + - etc. + +## Performance Considerations + +.. todo:: + Describe what parameters influence performance measures such as memory usage and start-up time. +""" + +from abc import abstractmethod +import collections +import copy +import gzip +from itertools import cycle, islice +import math +from multiprocessing import Pool +import os +from queue import Full, Queue +from random import Random +from threading import Thread +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + Iterator, + List, + Optional, + Tuple, + Union, +) + + +from .closablequeue import ClosableQueue, ClosedException + + +# TODO for n ext release: +# - benchmark the accuracy when using BlockwiseShuffleIterator vs. the BufferedShuffleIterator +# - change all convenience functions back to true classes, using a wrapper class + +# TODO later: +# - make iterator pipeline work for streaming data + + +def _advance_iterator(iterator: Iterator, n: int): + """Little helper to advance an iterator by n items""" + for _ in range(n): + next(iterator) + return n + + +class CheckpointableIterator(collections.abc.Iterator): + """ + Abstract base class that defines the interface for checkpointing. + + The interface (getstate, setstate) is inspired by Python's random package. + """ + + def __iter__(self): + return self + + @abstractmethod + def getstate(self) -> Dict: + """ + Get checkpoint of current state of iterator + + In a pipeline of iterators, this function __recursively__ calls itself on the preceeding iterator + and includes the gathered information in the returned checkpoint. + Thereby, to obtain a checkpoint of the state of an entire pipeline of iterators + you only have to call this function on the __last__ iterator in the pipeline. + A checkpoint is represented as a `dict`, + but the caller should treat a checkpoint as an opaque object + and not make any assumptions about the existence or meaning of the `dict` entries. + """ + pass + + @abstractmethod + def setstate(self, checkpoint: Optional[Dict]): + """ + Set state of iterator to given checkpoint + + In a pipeline of iterators, this function __recursively__ calls itself on the preceeding iterator. + Thereby, to set the state of an entire pipeline of iterators to a given checkpoint + you only have to call this function on the __last__ iterator in the pipeline. + + Args: + checkpoint: Checkpoint that should be used to reset the state of the iterator (or pipeline). 
+ If this is __None__, the state of the iterator (or pipeline) is reset to the initial + state immediately after construction. + """ + pass + + def __getstate__(self) -> Dict: # implementation of pickle Protocol + return self.getstate() + + def __setstate__(self, checkpoint: Optional[Dict]): + self.setstate(checkpoint) + + @abstractmethod + def __next__(self): + pass + + +class NativeCheckpointableIterator(CheckpointableIterator): + """ + Simple wrapper class that turns a Python Iterable into a CheckpointableIterator + + When calling setstate on this class, it simply replays the iterator all the way to the checkpoint one element at a time, + which makes it generally inefficient. + + Warning: This class cannot be used with Iterators (as opposed to Iterables), which have an `__iter__` function that simply returns self, but does not reset. + """ + + def __init__(self, iterable: Iterable): + # check whether iterable is iterable or iterator: + # if the variable iterable contains an iterator, the function __iter__ returns self + # if the variable iterable is an actual iterator, it should not return self + if iter(iterable) is iterable: + raise ValueError( + "It looks like you are passing an iterator instead of an iterable. This is not supported and can cause undefined behavior when used with checkpointing." + ) + self._input_iterable = iterable + self.setstate(None) + + def getstate(self) -> Dict: + return {"num_items_yielded": self._num_items_yielded} + + def setstate(self, checkpoint: Optional[Dict]): + self._iterator = iter(self._input_iterable) + self._num_items_yielded = ( + _advance_iterator(self._iterator, checkpoint["num_items_yielded"]) + if checkpoint is not None + else 0 + ) + + def __next__(self): + item = next( + self._iterator + ) # call this before increasing _num_items_yiel ded to correctly handle the case when a StopIteration exception is thrown + self._num_items_yielded += 1 + return item + + +def create_source_iterator( + source_items: List, + train: bool = True, + seed: Optional[int] = None, + shuffle: bool = True, + num_instances: int = 1, + instance_rank: int = 0, +): + if not train and shuffle: + raise ValueError("shuffling is not supported when train=False") + if train: + return InfinitePermutationSourceIterator( + source_items, + seed=seed, + shuffle=shuffle, + num_instances=num_instances, + instance_rank=instance_rank, + ) + else: + return ChunkedSourceIterator( + source_items, num_instances=num_instances, instance_rank=instance_rank + ) + + +def ChunkedSourceIterator( + source_items: List, num_instances: int = 1, instance_rank: int = 0 +): + """ + Cuts source list into chunks, one per instance, and serves out items in chunk corresponding to instance_rank + + This is a source iterator: + It is meant to be used at the beginning of a data loading pipeline. + As such, it takes a list as its source and not a CheckpointableIterator. + + Args: + source_items: input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it! + num_instances: number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training. + instance_rank: rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training. 
+ """ + # heuristic: assuming blocks are all of the same size, math.ceil should give us the shortest makespan + chunk_size = math.ceil(len(source_items) / num_instances) + # this does not cause any out-of-bounds issues: + # a slice with a start-index beyong the end of the list is empty, + # and an end-index of a slice is capped at the end of the list + chunk = source_items[instance_rank * chunk_size : (instance_rank + 1) * chunk_size] + return NativeCheckpointableIterator(chunk) + + +class InfinitePermutationSourceIterator(CheckpointableIterator): + """ + Infinitely generates permutations of the items in the given list. + + This is a source iterator: + It is meant to be used at the beginning of a data loading pipeline. + As such, it takes a list as its source and not a CheckpointableIterator. + The given list is loaded completely into RAM. + + For example, this is used for randomizing the pathnames of data blocks read by ChunkedReadlinesIterator. + """ + + def __init__( + self, + source_items: List, + seed: Optional[int] = None, + shuffle: bool = True, + num_instances: int = 1, + instance_rank: int = 0, + ): + """ + Args: + source_items: input list, must not be empty and must be small enough to fit into RAM entirely, ownership of the list and the data goes to the iterator, do not modify it! + seed: random seed used for shuffling (or None) + shuffle: set False to bypass the shuffling. Then this is just a checkpointed version of itertools.cycle(). (Default: True) + num_instances: number of instances of this iterator. Meant for use with multi-process data loading, e.g., in distributed training. + instance_rank: rank of this instance of the iterator. Meant for use with multi-process data loading, e.g., in distributed training. + """ + self._source_items = source_items + if not self._source_items: + raise ValueError("InfinitePermutationIterator: source must not be empty") + self._shuffle = shuffle + self._seed = seed + self._num_instances = num_instances + self._instance_rank = instance_rank + self.setstate(None) + + def getstate(self) -> Dict: + return { + "random_state": self._random_state, # state of random generator before generating the current shuffling of the sequence + "num_items_yielded": self._num_items_yielded, + } # how many items have already been iterated over in the current shuffling + + def setstate(self, checkpoint: Optional[Dict]): + # set iteration state. Do this outside the generator below in case getstate() is called before ever iterating + self._random_state = checkpoint["random_state"] if checkpoint else None + self._num_items_yielded = checkpoint["num_items_yielded"] if checkpoint else 0 + # We define the iteration itself as a generator for ease of implementation. + # We could as well just have used an explicit state machine represented by class members. 
+ def _generate() -> Iterator: + # create and reset random generator + random = Random(self._seed) + if self._random_state is not None: # restore the random generator's state + random.setstate(self._random_state) + skip_to_checkpoint = ( + self._num_items_yielded + ) # items to skip in order to advance to checkpoint + # main outer loop for infinite passes over items (reshuffle before each pass) + while True: + # (re-)shuffle all items + self._random_state = ( + random.getstate() + ) # remember random state before shuffling + self._num_items_yielded = 0 + shuffled_items = self._source_items[ + : + ] # note: if underlying iterator is checkpointable, use setstate(checkpoint['nested_state']) on it + if self._shuffle: + random.shuffle(shuffled_items) + shuffled_iterator = iter(shuffled_items) + # skip initial items when restarting from checkpoint + if ( + skip_to_checkpoint + ): # @TODO: find a way to abstract this more, so that we can plug it into the 'for' statement directly + self._num_items_yielded += _advance_iterator( + shuffled_iterator, skip_to_checkpoint + ) + skip_to_checkpoint = 0 # done skipping + # main inner loop over items + for item in shuffled_iterator: + self._num_items_yielded += 1 # record how many items we have iterated over in this pass over the items + if ( + self._num_items_yielded - 1 + ) % self._num_instances == self._instance_rank: # build-in islice facility + yield item + + self._iterator = _generate() + + def __next__(self): + return next(self._iterator) + + +class SelectManyIterator(CheckpointableIterator): + """ + Projects each element of a source sequence to a sequence and flattens the resulting sequences into one sequence. + """ + + def __init__( + self, + source_iterator: CheckpointableIterator, + collection_selector: Optional[Callable[[Any], Iterator]] = None, + ): + """ + Args: + source_iterator: iterator over the items to pass to collection_selector() + collection_selector: user callback that maps an item into an Iterable, whose items will be yielded. + The returned Iterator is used only once. Hence, it is also allowed to + return self-iterables, such as iterators and generator expressions. + If None is given, no callback is applied. 
+ """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + self._source_iterator = source_iterator # type: CheckpointableIterator + self._collection_selector = ( + collection_selector + ) # type: Callable[[Any], Iterator] + self.setstate(None) + + def getstate(self) -> Dict: + return { + "source_state": self._source_state, + "flattened_items_yielded": self._flattened_items_yielded, + } + + def setstate(self, checkpoint: Optional[Dict]): + self._source_state = checkpoint["source_state"] if checkpoint else None + self._flattened_items_yielded = ( + checkpoint["flattened_items_yielded"] if checkpoint else 0 + ) + self._source_iterator.setstate(self._source_state) + + def _generate(): + skip_to_checkpoint = self._flattened_items_yielded + # main loop over source source_items + for source_item in self._source_iterator: + if self._collection_selector is not None: + data = iter(self._collection_selector(source_item)) + else: + data = iter(source_item) + self._flattened_items_yielded = 0 + if skip_to_checkpoint: + # print("Skipping to index", skip_to_checkpoint, file=sys.stderr) + self._flattened_items_yielded += _advance_iterator( + data, skip_to_checkpoint + ) + skip_to_checkpoint = 0 + # main loop over lines + for item in data: + self._flattened_items_yielded += 1 + yield item + self._source_state = self._source_iterator.getstate() + + self._iterator = _generate() + + def __next__(self): + return next(self._iterator) + + +class BufferedShuffleIterator(CheckpointableIterator): + """ + Shuffles given iterable using a limited buffer. + """ + + def __init__( + self, source_iterator: CheckpointableIterator, buffer_size: int, seed: int = 0 + ): + """ + Args: + source_iterator: checkpointable iterator or restartable iterable over input items to shuffle + buffer_size: size of the buffer in number of items used for shuffling + seed: random seed used for shuffling (or None) + """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + self._source_iterator = source_iterator + self._buffer = [ + None for _ in range(buffer_size) + ] # maybe do this lazily? --Yes, since user may set state immediately, then this is not needed here + self._random = Random(seed) + self.setstate(None) + + def getstate(self) -> Dict: + return { + "source_state": self._source_iterator.getstate(), + "buffer": copy.deepcopy(self._buffer), + "random_state": self._random.getstate(), + } + + def setstate(self, checkpoint: Optional[Dict]): + if checkpoint: + self._source_iterator.setstate(checkpoint["source_state"]) + self._buffer = checkpoint["buffer"] + self._random.setstate(checkpoint["random_state"]) + # @TODO: Can we add a comment how the flush part is handled? 
+ else: + self._source_iterator.setstate(None) + self._iterator = self._generate() + + def _generate(self) -> Iterator: + # shuffle data with a buffer: + # this is similar to what the Fisher-Yates shuffle does, + # but modified to run with a constant-size buffer + # see https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle + # this was inspired by an algorithm implemented in Kaldi + # see https://kaldi-asr.org/doc/nnet-shuffle-egs_8cc.html + for item in self._source_iterator: + index = self._random.randrange(0, len(self._buffer)) + result = None + if self._buffer[index] is not None: + result = self._buffer[index] + self._buffer[index] = item + # only yield value once buffer is updated to allow for correct checkpointing! + if result i s not None: + yield result + + # flush buffer + while self._buffer: + item = self._buffer.pop() + if item is not None: + yield item + + def __next__(self): + return next(self._iterator) + + +class MapIterator(CheckpointableIterator): + """ + Applies given tranform to each data item + """ + + def __init__( + self, source_iterator: CheckpointableIterator, transform: Callable[[str], Any] + ): + """ + Args: + source_iterator: checkpointable iterator + transform: function to be applied to each data item + """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + self._source_iterator = source_iterator + self._transform = transform + + def getstate(self) -> Dict: + return self._source_iterator.getstate() + + def setstate(self, checkpoint: Optional[Dict]): + self._source_iterator.setstate(checkpoint) + + def __next__(self): + return self._transform(next(self._source_iterator)) + + +def ParallelMapIterator( + source_iterator: CheckpointableIterator, + transform: Callable[[str], Any], + num_processes: int, + num_items_per_process: int, +): + """ + Applies given transform to each data item + + Behaves the same as MapIterator, but applies transform in parallel using multiple processes in a parallel map operation. + + Warning: + The transform function has to be pickleable because it is sent across process boundaries. + To achieve this, transform should be a top-level function. + + Args: + source_iterator: checkpointable iterator + transform: function to be applied to each data item, has to be pickleable, see above + num_processes: number of processes to use for parallel map + num_items_per_process: number of data items each process operates on + """ + # divide stream of data items into batches + batched_samples = FixedBatchIterator( + source_iterator, num_processes * num_items_per_process + ) + # create process pool and capture it in closure that performs parallel map + p = Pool(num_processes) + + def parallel_map_transform(buffer): + return p.map(transform, buffer) + + # apply transform in parallel to data items in a batch + batched_transformed_samples = MapIterator(batched_samples, parallel_map_transform) + # unpack batches to go back to stream of (now transformed) data items + transformed_samples = SelectManyIterator(batched_transformed_samples) + return transformed_samples + + +class ZipIterator(CheckpointableIterator): + """ + Zips items from all given iterators, like the Python standard function zip(). + + Like Python's build-in zip(), the iteration stops when the shortest input iterable is exhausted. 
+ """ + + def __init__(self, *source_iterators: CheckpointableIterator): + """ + Args: + source_iterators: list of iterators to zip, item by item + """ + for source_iterator in source_iterators: + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError( + "all iterators in source_iterators have to be CheckpointableIterator" + ) + self._source_iterators = source_iterators # type: List[CheckpointableIterator] + + def getstate(self) -> Dict: + return { + "input_states": tuple( + iterator.getstate() for iterator in self._source_iterators + ) + } + + def setstate(self, checkpoint: Optional[Dict]): + if checkpoint is None: + for iterator in self._source_iterators: + iterator.setstate(None) + else: + for iterator, state in zip( + self._source_iterators, checkpoint["input_states"] + ): + iterator.setstate(state) + + def __next__(self): + res = ( + [] + ) # (note: can't use a generator expression, as it gets confused when a next() call raises StopIteration) + for iterator in self._source_iterators: + res.append(next(iterator)) + return tuple(res) + + +# @TODO: The yield makes a (shallow) copy of the window, which has complexity O(width * length). In some cases, +# we don't actually need to consume all items in the window. Hence, to make this faster, we should use +# double-buffering and return a slice view (which we'd have to write). +class WindowedIterator(CheckpointableIterator): + """ + Yields 'width' consecutive items in a sliding window. + + E.g. [1, 2, 3, 4, 5, 6] with width = 3 will yield + [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]] + """ + + def __init__(self, source_iterator: CheckpointableIterator, width: int): + """ + Args: + source_iterator: checkpointable input iterators + """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + self._source_iterator = source_iterator # type: CheckpointableIterator + self._width = width # type: int + self.setstate(None) + + def getstate(self) -> Dict: + return { + "source_state": self._source_state, # state for first item in FIFO + "item_index": self._item_index, + } # index of next item to serve + + def setstate(self, checkpoint: Optional[Dict]): + self._source_state = checkpoint["source_state"] if checkpoint else None + self._item_index = checkpoint["item_index"] if checkpoint else 0 + self._source_iterator.setstate(self._source_state) + self._iterator = self._generate() + + def _fifo_slice(self, i): # returns a window into the FIFO beginning at i + # @TODO: for efficiency, make this a slice view + return tuple(self._fifo[i : i + self._width]) + + def _generate(self) -> Iterator: + self._source_state = self._source_iterator.getstate() + self._fifo = list(islice(self._source_iterator, self._width)) + # we do this in overlapping blocks of length 2*width, for easier checkpointing and potential efficiency + while len(self._fifo) == self._width: + # we got 'width' items; append another 'width' (or less if at end) + next_input_state = self._source_iterator.getstate() + self._fifo.extend(islice(self._source_iterator, self._width)) + # now serve all positions in first half (last = width - 1). If at end, then limit accordingly. 
+ last = min(self._width - 1, len(self._fifo) - self._width) + while self._item_index <= last: + window = self._fifo_slice(self._item_index) + self._item_index += 1 + yield window + # drop all we just served; if < width left, we have hit the end + self._fifo = self._fifo[ + last + 1 : + ] # Note: This must be a new list, since the old might still be in a slice view. + self._source_state = ( + next_input_state # this reflects now the first element in the FIFO + ) + self._item_index = 0 + + def __next__(self): + return next(self._iterator) + + +# @TODO: research on whether this operation has a well-known name +class FixedBatchIterator(CheckpointableIterator): + """ + Batches N consecutive items into a single item that is a list of these items. + + E.g. [1, 2, 3 4, 5, 6, 7, 8] with batch_size = 3 will yield + [(1, 2, 3), (4, 5, 6), (7, 8)] + """ + + def __init__(self, source_iterator: CheckpointableIterator, batch_size: int): + """ + Args: + source_iterator: checkpointable input iterators + batch_size: number of items per batch + """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + self._source_iterator = source_iterator # type: CheckpointableIterator + self._batch_size = batch_size # type: int + self.setstate(None) + + def getstate(self) -> Dict: + return { + "source_state": self._source_iterator.getstate() + } # state for first item in next batch + + def setstate(self, checkpoint: Optional[Dict]): + self._source_state = checkpoint["source_state"] if checkpoint else None + self._source_iterator.setstate(self._source_state) + self._iterator = self._generate() + + def _generate(self) -> Iterator: + while True: + batch = list(islice(self._source_iterator, self._batch_size)) + if not batch: + break + yield batch + + def __next__(self): + return next(self._iterator) + + +class RandomIterator(CheckpointableIterator): + """ + Iterator to generate uniformly distributed random numbers in the interval [0,1). + Very similar to Random.random(), except that random numbers are + obtained via next(). + """ + + def __init__(self, seed: Optional[int] = None): + """ + Args: + seed: Random seed. + """ + self._random = Random() # type: Random + if seed is not None: + self._random.seed(seed) + + def getstate(self) -> Dict: + return {"random_state": self._random.getstate()} + + def setstate(self, checkpoint: Optional[Dict]): + self._random.setstate(checkpoint["random_state"] if checkpoint else None) + + def __next__(self): + return self._random.random() + + +class RecurrentIterator(CheckpointableIterator): + """ + Iterates statefully over a step function. The step function accepts a state and a new item, + and returns a new state and an output item, which is yielded. 
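+
+    Example (an illustrative sketch of a running sum, using NativeCheckpointableIterator from this module)::
+
+        def step_function(state, item):
+            new_state = state + item
+            return new_state, new_state
+
+        it = RecurrentIterator(NativeCheckpointableIterator([1, 2, 3, 4]), step_function, initial_state=0)
+        list(it)   # -> [1, 3, 6, 10]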
+ """ + + def __init__( + self, + source_iterator: CheckpointableIterator, + step_function: Callable[[Any, Any], Tuple[Any, Any]], + initial_state: Any = None, + ): + """ + Args: + source_iterator: checkpointable iterator to recur over + step_function: user-supplied function with signature step_function(state, item) -> (new_state, output) + initial_state: initial state to be passed to the step_function upon first invocation + """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + self._source_iterator = source_iterator # type: CheckpointableIterator + self._step_function = step_function # type: Callable[[Any,Any], Tuple[Any,Any]] + self._initial_state = initial_state # type: Any + self.setstate(None) + + def getstate(self): + return { + "recurrent_state": self._recurrent_state, + "source_state": self._source_iterator.getstate(), + } + + def setstate(self, checkpoint): + self._recurrent_state = ( + checkpoint["recurrent_state"] if checkpoint else self._initial_state + ) + self._source_iterator.setstate( + checkpoint["source_state"] if checkpoint else None + ) + + def _generate(): + for item in self._source_iterator: + self._recurrent_state, output = self._step_function( + self._recurrent_state, item + ) + yield output + + self._iterator = _generate() + + def __next__(self): + return next(self._iterator) + + +def SamplingRandomMapIterator( + source_iterator: CheckpointableIterator, + transform: Callable[[Random, Any], Any], + seed: Optional[int] = None, +): + """ + An iterator that calls a transform function on each item, while also passing a checkpointed + random generator. + + Args: + source_iterator: checkpointable iterator to recur over + step_function: user-supplied function with signature step_function(random, item) -> result_item + seed: random seed + """ + _random = Random() + if seed is not None: + _random.seed(seed) + + def _step_function(state, item): + _random.setstate(state) + output = transform(_random, item) + return _random.getstate(), output + + return RecurrentIterator( + source_iterator, _step_function, initial_state=_random.getstate() + ) + + +def BlockwiseShuffleIterator( + source_iterator: CheckpointableIterator, block_size: int, seed: int = 0 +): + """ + Shuffles a sequence of items by grouping consecutive items in blocks of fixed size, shuffling + each block, and yielding the shuffled items of all blocks as a flat sequence. + + E.g. [1, 2, 3, 4, 5, 6, 7, 8] with block_size = 3 may yield [3, 1, 2, 4, 6, 5, 8, 7]. + + Args: + source_iterator: checkpointable iterator or restartable iterable over input items to shuffle + block_size: size of the buffer in number of items used for shuffling + seed: random seed used for shuffling (or None) + """ + # This is implemented as a pipeline: + # - group N consecutive items together + # - shuffle them + # - flatten the result + blocks = FixedBatchIterator(source_iterator, batch_size=block_size) + + def shuffle_block_fn(random: Random, block: List): + random.shuffle(block) + return block + + shuffled_blocks = SamplingRandomMapIterator( + blocks, transform=shuffle_block_fn, seed=seed + ) + samples = SelectManyIterator( + shuffled_blocks, collection_selector=lambda shuffled_block: iter(shuffled_block) + ) + return samples + + +class PrefetchIterator(CheckpointableIterator): + """ + An iterator prefetching data into a buffer on a seperate thread to smooth out IO latency. 
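+
+    Example (an illustrative sketch; any CheckpointableIterator can be wrapped)::
+
+        prefetched = PrefetchIterator(NativeCheckpointableIterator(list(range(1000))), buffer_size=100)
+        first_item = next(prefetched)   # items are filled into the buffer by a background thread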
+ + Args: + source_iterator: checkpointable iterator to recur over + buffer_size: size of the queue between the threads + """ + + def __init__( + self, source_iterator: CheckpointableIterator, buffer_size: int = 1000 + ): + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + self._source_iterator = source_iterator # type:CheckpointableIterator + self._buffer_size = buffer_size # type: int + self._queue = None # type: Optional[ClosableQueue] + self._thread = None # type: Optional[Thread] + self.setstate(None) + + def getstate(self) -> Dict: + return {"source_state": self._source_state, "item_offset": self._item_offset} + + def setstate(self, checkpoint: Optional[Dict]): + if ( + self._thread is not None + ): # if there is a prefetching thread running, close the queue and wait for the thread to terminate + assert self._queue is not None + self._queue.close() + self._thread.join() + + self._source_state = ( + checkpoint["source_state"] if checkpoint is not None else None + ) + self._item_offset = checkpoint["item_offset"] if checkpoint is not None else 0 + + self._source_iterator.setstate(self._source_state) + + self._queue = ClosableQueue(maxsize=self._buffer_size) # clear queue + # make thread daemonic so it is killed when the main program terminates + self._thread = Thread( + target=self._prefetch_thread_fn, + args=( + self._source_iterator, + self._item_offset, + self._buffer_size, + self._queue, + ), + daemon=True, + ) + self._thread.start() + + @staticmethod + def _prefetch_thread_fn( + source, item_offset, buffer_size, queue + ): # behavior of the prefetching thread, only call from that thread! + _advance_iterator(source, item_offset) # skip to checkpoint + + while True: + try: + item = next(source) + except StopIteration: + queue.close() + return + + if ( + item_offset == buffer_size - 1 + ): # send a new source state a the END of each window of length _buffer_size + source_state = ( + source.getstate() + ) # this is the state for retrieving the NEXT element, i.e. the first element of the next buffer + item_offset = 0 + else: + source_state = None + item_offset += 1 + msg = (item, source_state) + + try: + queue.put(msg) + except ClosedException: + return + + def __next__(self): + try: + msg = self._queue.get() + except ClosedException: + raise StopIteration + + item, prefetch_source_state = msg + if prefetch_source_state is not None: + assert ( + self._item_offset == self._buffer_size - 1 + ) # we expect a new source state at then END of each window of length _buffer_size + self._source_state = prefetch_source_state + self._item_offset = 0 + else: + self._item_offset = self._item_offset + 1 + assert self._item_offset < self._buffer_size + return item # for debugging, its useful to return msg instead of item + + def __del__( + self, + ): # note: this is often not called. If you really need it, gc.collect() will do the trick. + if self._thread is not None: + assert self._queue is not None + self._queue.close() + try: + self._thread.join() + except: + pass + + +class BucketedReadaheadBatchIterator(CheckpointableIterator): + """ + Iterates over items from a checkpointable iterator and groups items of similar length into batches. + + The algorithm reads a head a certain number of lines (e.g. 10 million), sorts them by + length, and them groups them into batches from start to end. The sort is stable, such + that prior randomization is not undone (except for the length grouping). 
The batch size + is dynamic, and determined by a user-provided callback. + + This is based on Marian NMT's BatchGenerator. + """ + + def __init__( + self, + source_iterator: CheckpointableIterator, + read_ahead: int, + key: Callable[[Any], Any], + batch_size: Union[int, Callable[[Any], int]], + shuffle: bool = True, + seed: Optional[int] = None, + ): + """ + Args: + source_iterator: The data set that is read from. Typically this is an infinite source. + read_ahead: Number of items to fetch ahead for grouping purposes. + key: User-provided callback to define how data is sorted for purpose of batching. + batch_size: Batch size in number of items. Either an integer or a callback to determine batch size for a given first batch item. + shuffle: Pass False to not randomize the batches. (default: True) + seed: Random seed for batch shuffling. + """ + if not isinstance(source_iterator, CheckpointableIterator): + raise ValueError("source_iterator has to be a CheckpointableIterator") + # keep arguments + self._key = key # type: Callable[[Any], Any] + self._batch_size = batch_size # type: Union[int,Callable[[Any], int]] + self._read_ahead = read_ahead # type: int + # initialize state + self._random = None + if shuffle: + self._random = Random() # type: Random + if seed is not None: + self._random.seed(seed) + self._source_iterator = iter(source_iterator) # type: CheckpointableIterator + self.setstate(None) + + def getstate(self): + return { + "source_state": self._source_state, + "random_state": self._random_state, + "num_served": self._num_batches_yielded, + } + + def setstate(self, checkpoint: Optional[Dict]): + self._source_state = ( + checkpoint["source_state"] if checkpoint else None + ) # type: Dict -- state of input before reading the current set of batches + self._random_state = ( + checkpoint["random_state"] if checkpoint else None + ) # type: Any -- state of random generator at _source_state + self._num_batches_yielded = ( + checkpoint["num_served"] if checkpoint else 0 + ) # type: int -- number of batches served from the current set of batches + # checkpointing: restore to start of current set of batches + self._source_iterator.setstate(self._source_state) + if self._random_state: + self._random.setstate(self._random_state) + self._source_exhausted = ( + False + ) # type: bool -- set to True once we hit StopIteration on source + + def _generate(): + skip_to_checkpoint = self._num_batches_yielded + source_exhausted = False + while not source_exhausted: + # prefetch the readahead buffer + self._source_state = self._source_iterator.getstate() + self._random_state = self._random.getstate() if self._random else None + items = list(islice(self._source_iterator, self._read_ahead)) + source_exhausted = len(items) < self._read_ahead + # create batches + batches = self._create_batches(items) + # shuffle the batches + if self._random: + self._random.shuffle(batches) + # on first loop iteration, restore iterator inside batches from checkpoint + batches = iter(batches) + self._num_batches_yielded = _advance_iterator( + batches, skip_to_checkpoint + ) + skip_to_checkpoint = 0 + # main loop over batches in current read-ahead section + for batch in batches: + self._num_batches_yielded += 1 + yield batch + + self._iterator = ( + _generate() + ) # type: Iterator -- iterator into current set of batches + + def _create_batches( + self, items: List[Any] + ) -> List[List[Any]]: # helper to form batches from a list of items + # sort by length, longest first + if self._key: + items.sort( + key=self._key, 
reverse=True + ) # note: sort() is stable, so we won't undo any randomization besides the bucketing + # group into batches + cur_batch = None + batches = [] + for item in items: + if not cur_batch: + batch_size = ( + self._batch_size + if isinstance(self._batch_size, int) + else self._batch_size(item) + ) + cur_batch = [] + cur_batch.append(item) + if len(cur_batch) >= batch_size: # this batch is full + batches.append(cur_batch) + cur_batch = None + if cur_batch: + batches.append(cur_batch) + return batches + + def __next__(self): + return next(self._iterator) diff --git a/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/torch/__init__.py b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/torch/data.py b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/torch/data.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2d91a4b64d2d8f484d0adefd514889748218b9 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/infinibatch/torch/data.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +from infinibatch.iterators import CheckpointableIterator +from infinibatch .datasets import chunked_dataset_iterator +from typing import Union, Iterable, Any + + +# @TODO: This has been tested once, but we have no regression test presently. I am worried tests will fail if Torch is not installed. +class IterableCheckpointedDataset(torch.utils.data.IterableDataset): + """ + Wraps a CheckpointableIterator into a PyTorch IterableDataset, which is recognized by its type by + PyTorch's DataLoader class. + """ + + def __init__(self, source: CheckpointableIterator): + super().__init__() + self._source = source + + def __iter__(self): # this is called in the forked clone + worker_info = torch.utils.data.get_worker_info() + assert ( + worker_info is None or worker_info.num_workers == 1 + ) # not supported since we can't get at the checkpoint for each worker + return iter(self._source) + + +# @TODO: This is currently untested, and may not work presently. 
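+#
+# Illustrative usage sketch for IterableCheckpointedDataset above (hypothetical `train_source`,
+# e.g. a CheckpointableIterator built with chunked_dataset_iterator); shown as a comment only:
+#
+#   dataset = IterableCheckpointedDataset(train_source)
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=0)
+#   for item in loader:
+#       ...  # items come out exactly as produced by train_source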
+class IterableChunkedDataset(torch.utils.data.IterableDataset):
+    def __init__(
+        self,
+        paths: Union[str, Iterable[str]],
+        shuffle: bool = True,
+        buffer_size: int = 2 ** 20,
+        transform=None,
+        seed: int = None,
+        world_size: int = 1,
+        rank: int = 0,
+        num_workers_per_rank: int = 1,
+    ):
+        super().__init__()
+        self.rank = rank
+        self.num_workers_per_rank = num_workers_per_rank
+        # instance_rank is set assuming that num_workers_per_rank = 1 and adapted dynamically in __iter__
+        self.dataset = chunked_dataset_iterator(
+            paths,
+            shuffle=shuffle,
+            buffer_size=buffer_size,
+            transform=transform,
+            seed=seed,
+            num_instances=world_size * num_workers_per_rank,
+            instance_rank=rank,
+        )
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:  # single-process data loading
+            self.dataset._instance_rank = self.rank
+        else:
+            assert worker_info.num_workers == self.num_workers_per_rank
+            self.dataset._instance_rank = (
+                self.rank * self.num_workers_per_rank + worker_info.id
+            )
+        return iter(self.dataset)
diff --git a/model/third_party/HMNet/DataLoader/infinibatch/requirements.txt b/model/third_party/HMNet/DataLoader/infinibatch/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/third_party/HMNet/DataLoader/infinibatch/setup.py b/model/third_party/HMNet/DataLoader/infinibatch/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/third_party/HMNet/DataLoader/infinibatch/test/test_closablequeue.py b/model/third_party/HMNet/DataLoader/infinibatch/test/test_closablequeue.py
new file mode 100644
index 0000000000000000000000000000000000000000..440db98370df2f09e80dcd29574cb3165f57107c
--- /dev/null
+++ b/model/third_party/HMNet/DataLoader/infinibatch/test/test_closablequeue.py
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from threading import Thread
+import unittest
+
+from infinibatch.closablequeue import ClosableQueue, ClosedException
+
+
+class TestClosableQueue(unittest.TestCase):
+    def setUp(self):
+        self.queue = ClosableQueue(maxsize=10)
+
+    def put_items(self, items, close=False):
+        for item in items:
+            self.queue.put(item)
+        if close:
+            self.queue.close()
+
+    def get_items(self, num_items):
+        return [self.queue.get() for _ in range(num_items)]
+
+    def test_basic(self):
+        self.put_items(range(10))
+        self.assertListEqual(self.get_items(10), list(range(10)))
+
+    def test_closed_put(self):
+        self.queue.close()
+        self.assertRaises(ClosedException, self.queue.put, 42)
+
+    def test_closed_get(self):
+        self.put_items(range(10))
+        self.queue.close()
+        self.assertListEqual(self.get_items(10), list(range(10)))
+        self.assertRaises(ClosedException, self.queue.get)
+
+    def test_basic_two_threads(self):
+        thread = Thread(target=self.put_items, args=(range(20),))
+        thread.start()
+        result = self.get_items(20)
+        thread.join()
+        self.assertListEqual(result, list(range(20)))
diff --git a/model/third_party/HMNet/DataLoader/infinibatch/test/test_doctests.py b/model/third_party/HMNet/DataLoader/infinibatch/test/test_doctests.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d2bfa6d32663cbe4223bf94346aadce247c6ea
--- /dev/null
+++ b/model/third_party/HMNet/DataLoader/infinibatch/test/test_doctests.py
@@ -0,0 +1,17 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+ +""" +This file causes the doctests to be included as part of unit tests. + +To make sure the doctests of a specific module are included, +please replicate the `addTests` call for the iterators module below. +""" + +import doctest +import infinibatch.iterators + + +def load_tests(loader, tests, ignore): + tests.addTests(doctest.DocTestSuite(infinibatch.iterators)) + return tests diff --git a/model/third_party/HMNet/DataLoader/infinibatch/test/test_iterators.py b/model/third_party/HMNet/DataLoader/infinibatch/test/test_iterators.py new file mode 100644 index 0000000000000000000000000000000000000000..08d5e2465dec4f684435fb1663bd9566a8cfc27b --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/test/test_iterators.py @@ -0,0 +1,601 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import gzip +import itertools +from random import Random +import os +import shutil +import tempfile +from typing import Iterable, Iterator, Any, Union +import unittest +import pickle +import gc + +from infinibatch.iterators import ( + create_source_iterator, + ChunkedSourceIterator, + InfinitePermutationSourceIterator, + BufferedShuffleIterator, + BlockwiseShuffleIterator, + NativeCheckpointableIterator, + BucketedReadaheadBatchIterator, + MapIterator, + ParallelMapIterator, + ZipIterator, + FixedBatchIterator, + WindowedIterator, + SelectManyIterator, + RandomIterator, + RecurrentIterator, + SamplingRandomMapIterator, + PrefetchIterator, +) +from infinibatch.datasets import chunked_dataset_iterator + + +# TODO: +# - make sure that all iterators can be reset to a checkpoint even after they were exhausted +# - make sure that all iterators can be reset to a checkpoint that was taken after the iterator was exhausted +# - make sure that all iterators can be reset to a checkpoint at the beginning of the iteration +# - refactor test cases that do not rely on TestCheckpointableIterator +# - make sure every iterator is tested for correct checkpointing at the end of the iterator + + +class TestCheckpointableIterator: + """ + These are common test cases for CheckointableIterators + + Inherit from this class and set self.iterator and self.expected_result in the setUp function to use. 
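+
+    A minimal sketch of a concrete subclass (hypothetical TestMyIterator; the real test cases below
+    follow the same pattern)::
+
+        class TestMyIterator(unittest.TestCase, TestCheckpointableIterator):
+            def setUp(self):
+                data = list(range(53))
+                self.expected_result = data
+                self.iterator = NativeCheckpointableIterator(data)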
+ """ + + def test_basic(self): + self.assertListEqual(list(self.iterator), self.expected_result) + + def test_checkpointing_from_start(self): + for _ in range(len(self.expected_result)): + next(self.iterator) + self.iterator.setstate(None) + self.assertListEqual(list(self.iterator), self.expected_result) + + def test_checkpointing_in_middle(self): + result = [next(self.iterator) for _ in range(len(self.expected_result) // 3)] + self.iterator.setstate(self.iterator.getstate()) + result += [item for item in self.iterator] + self.assertListEqual(result, self.expected_result) + + def test_checkpointing_at_end(self): + for _ in range(len(self.expected_result)): + next(self.iterator) + self.iterator.setstate(self.iterator.getstate()) + self.assertRaises(StopIteration, self.iterator.__next__) + + +class TestBase(unittest.TestCase): + def setUp(self): + self.test_data = [ + [ + "item number one", + "item number two", + "item number three", + "item number four", + ], + ["item number five"], + [ + "item number six", + "item number seven", + "item number eight", + "item number nine", + "item number ten", + "item number eleven", + ], + [ + "item number twelve", + "item number thirteen", + "item number fourteen", + ], + ] + + self.flattened_test_data = [] + for chunk in self.test_data: + for item in chunk: + self.flattened_test_data.append(item) + + self.data_dir = tempfile.mkdtemp() + self.chunk_file_paths = [] + for chunk_id, chunk in enumerate(self.test_data): + file_name = os.path.join( + self.data_dir, "chunk_" + str(chunk_id).zfill(10) + ".gz" + ) + self.chunk_file_paths.append(file_name) + file_content = "\n".join(chunk) + with gzip.open(file_name, "wt", encoding="utf-8") as f: + f.write(file_content) + + @staticmethod + def read_chunk( + textfile_path: str, + ) -> Iterator[str]: # read_chunk_fn for chunked_dataset_iterator + with gzip.open(textfile_path, "rt", encoding="utf-8") as f: + return iter(f.read().splitlines()) + + def tearDown(self): + gc.collect() # this will get the pre-fetch terminated in some tests, which otherwise may still want to read these files + shutil.rmtree(self.data_dir) + + def assertMultisetEqual(self, a, b): + self.assertEqual(len(a), len(b)) + self.assertSetEqual(set(a), set(b)) + + +class TestSourceIterator(unittest.TestCase): + def test_exception(self): + self.assertRaises( + ValueError, create_source_iterator, [1], train=False, shuffle=True + ) + + +class TestChunkedSourceIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + self.expected_result = list(range(53)) + self.iterator = ChunkedSourceIterator(self.expected_result) + + def test_multiple_instance(self): + for num_instances in range(2, 17): + items = [] + for rank in range(num_instances): + iterator = ChunkedSourceIterator( + self.expected_result, + num_instances=num_instances, + instance_rank=rank, + ) + items.extend(list(iterator)) + self.assertListEqual(items, self.expected_result) + + +class TestInfinitePermutationSourceIterator(TestBase): + def test_repeat_once(self): + # This tests that two consecutive iterations through the test data yields differently ordered sequences. 
+ reader = iter(InfinitePermutationSourceIterator(self.flattened_test_data, 42)) + items0 = list(itertools.islice(reader, len(self.flattened_test_data))) + items1 = list(itertools.islice(reader, len(self.flattened_test_data))) + self.assertMultisetEqual(items0 + items1, self.flattened_test_data * 2) + self.assertTrue(any(item0 != item1 for item0, item1 in zip(items0, items1))) + + def test_reiter_once(self): + # This differs from test_repeat_once in that we use checkpoints. + reader = InfinitePermutationSourceIterator(self.flattened_test_data, 42) + checkpoint = reader.getstate() + items0 = list(itertools.islice(reader, len(self.flattened_test_data))) + reader.setstate(checkpoint) + items1 = list(itertools.islice(reader, len(self.flattened_test_data))) + self.assertMultisetEqual(items0 + items1, self.flattened_test_data * 2) + self.assertSequenceEqual(items0, items1) + + def test_checkpointing(self): + random = Random() + for i in range(5): + # random sequence lengths t o for testing different configurations + test_source_length = random.randrange(5, 25) + test_first_output_length = random.randrange(5, 25) + test_second_output_length = random.randrange(5, 25) + # source + test_source = list(range(test_source_length)) + reader = InfinitePermutationSourceIterator(test_source, seed=i) + # fetch a first sequence + _ = list(itertools.islice(reader, test_first_output_length)) + # fetch a second sequence + checkpoint = reader.getstate() + items1a = list(itertools.islice(reader, test_second_output_length)) + # fetch that second sequence again via checkpointing + reader.setstate(checkpoint) + items1b = list(itertools.islice(reader, test_second_output_length)) + # and again with serialized checkpoint + as_json = pickle.dumps(checkpoint) + checkpoint2 = pickle.loads(as_json) + reader.setstate(checkpoint2) + items1c = list(itertools.islice(reader, test_second_output_length)) + # must be the same + self.assertTrue(items1a == items1b) + self.assertTrue(items1a == items1c) + + +class TestNativeCheckpointableIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + self.expected_result = list(range(53)) + self.iterator = NativeCheckpointableIterator(self.expected_result) + + def test_iterator_exception(self): + self.assertRaises(ValueError, NativeCheckpointableIterator, iter(range(10))) + + +class TestRecurrentIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + data = list(range(53)) + + self.expected_result = [0] + for i in data[1:]: + self.expected_result.append(self.expected_result[-1] + i) + + def step_function(prev_state, item): + output = item + prev_state + new_state = output + return new_state, output + + self.iterator = RecurrentIterator( + NativeCheckpointableIterator(data), step_function, initial_state=0 + ) + + +class TestSamplingRandomMapIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + data = list(range(53)) + + def transform(random: Random, item: int): + return item + random.random() + + seed = 1 + random = Random() + random.seed(seed) + self.expected_result = [n + random.random() for n in data] + + self.iterator = SamplingRandomMapIterator( + NativeCheckpointableIterator(data), transform=transform, seed=seed + ) + + +class TestFixedBatchIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + data = list(range(5)) + + batch_size = 3 + self.expected_result = [data[0:3], data[3:]] + + self.iterator = FixedBatchIterator( + NativeCheckpointableIterator(data), batch_size=batch_size + ) + + +class 
TestSelectManyIterator(TestBase): + # in this test, SelectManyIterator is used to read chunk files + @staticmethod + def _select_many_from_chunks(chunk_file_paths): + return SelectManyIterator( + source_iterator=chunk_file_paths, collection_selector=TestBase.read_chunk + ) + + def test(self): + items = list( + self._select_many_from_chunks( + NativeCheckpointableIterator(self.chunk_file_paths) + ) + ) + self.assertListEqual(items, self.flattened_test_data) + + def test_no_selector(self): + data = list(range(100)) + sublists = [data[:10], data[10:42], data[42:87], data[87:]] + result = list(SelectManyIterator(NativeCheckpointableIterator(sublists))) + self.assertListEqual(result, data) + + def test_different_line_endings(self): + # write data in binary mode with LF line endings + lf_dir = tempfile.mkdtemp() + lf_file = os.path.join(lf_dir, "test.gz") + with gzip.o pen(lf_file, "w") as f: + f.write("\n".join(self.flattened_test_data).encode("utf-8")) + + # write data in binary mode with CRLF line endings + crlf_dir = tempfile.mkdtemp() + crlf_file = os.path.join(crlf_dir, "test.gz") + with gzip.open(crlf_file, "w") as f: + f.write("\r\n".join(self.flattened_test_data).encode("utf-8")) + + lf_data = list( + self._select_many_from_chunks(NativeCheckpointableIterator([lf_file])) + ) + crlf_dat = list( + self._select_many_from_chunks(NativeCheckpointableIterator([crlf_file])) + ) + self.assertListEqual(lf_data, crlf_dat) + + shutil.rmtree(lf_dir) + shutil.rmtree(crlf_dir) + + def test_checkpointing(self): + chunk_file_paths = [ + os.path.join(self.data_dir, subpath.name) + for subpath in os.scandir(self.data_dir) + if subpath.is_file() and subpath.name.endswith(".gz") + ] + chunk_file_paths = InfinitePermutationSourceIterator( + chunk_file_paths, shuffle=False + ) # using this as checkpointed cycle() + random = Random(1) + for _ in range(5): + first_length = random.randrange(11, 31) + extra_length = random.randrange(11, 33) + dataset = self._select_many_from_chunks(chunk_file_paths) + for _ in range(first_length): + next(dataset) + checkpoint = dataset.getstate() + items0 = list(itertools.islice(dataset, extra_length)) + # print(len(items0)) + dataset.setstate(checkpoint) + items1 = list(itertools.islice(dataset, extra_length)) + # print(len(items1)) + self.assertListEqual(items0, items1) + + +class TestBufferedShuffleIterator(TestBase): + def test_shuffle(self): + # work on copy of data in case data is modified by class + items = list( + BufferedShuffleIterator( + NativeCheckpointableIterator(self.flattened_test_data.copy()), 971, 42 + ) + ) + self.assertMultisetEqual(items, self.flattened_test_data) + + def test_shuffle_buffer_size_one(self): + # work on copy of data in case data is modified by class + items = list( + BufferedShuffleIterator( + NativeCheckpointableIterator(self.flattened_test_data.copy()), 1, 42 + ) + ) + self.assertListEqual(items, self.flattened_test_data) + + +# note: this is also tested in more depth in Test_chunked_dataset_iterator() +class TestBlockwiseShuffleIterator(TestBase): + def test_shuffle(self): + # work on copy of data in case data is modified by class + items = list( + BlockwiseShuffleIterator( + NativeCheckpointableIterator(self.flattened_test_data.copy()), 971, 42 + ) + ) + self.assertMultisetEqual(items, self.flattened_test_data) + + def test_shuffle_buffer_size_one(self): + # work on copy of data in case data is modified by class + items = list( + BlockwiseShuffleIterator( + NativeCheckpointableIterator(self.flattened_test_data.copy()), 1, 42 + ) + ) + 
self.assertListEqual(items, self.flattened_test_data) + + +def map_fun(n): + return n + 1 + + +class TestMapIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + data = list(range(53)) + self.expected_result = [map_fun(n) for n in data] + self.iterator = MapIterator(NativeCheckpointableIterator(data), map_fun) + + +class TestParallelMapIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + data = list(range(53)) + self.expected_result = [map_fun(n) for n in data] + self.iterator = ParallelMapIterator( + NativeCheckpointableIterator(data), map_fun, 5, 7 + ) + + +class TestZipIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + data1 = list(range(53)) + data2 = [n * n for n in data1] + self.expected_result = list(zip(data1, data2)) + self.iterator = ZipIterator( + NativeCheckpointableIterator(data1), NativeCheckpointableIterator(data2) + ) + + +class TestWindowedIterator(TestBase): + def test(self): + for n in [0, 2, 3, 8, 9, 10, 11, 12]: # cover various boundary conditions + seq = list(range(n)) + it = WindowedIterator(NativeCheckpointableIterator(seq), 3) + actual0 = list(itertools.islice(it, n * 3 // 10)) + checkpoint = it.getstate() + actual1a = list(it) + it.setstate(checkpoint) + actual1b = list(it) + actual = actual0 + actual1a + expected = list( + zip(seq, itertools.islice(seq, 1, None), itertools.islice(seq, 2, None)) + ) + self.assertListEqual(actual, expected) # basic operation + self.assertListEqual(actual1a, actual1b) # checkpointing + + +class TestRandomIterator(TestBase): + def test(self): + n = 100 + it = RandomIterator(seed=1) + _ = list(itertools.islice(it, n * 3 // 10)) + checkpoint = it.getstate() + items1a = list(itertools.islice(it, n * 7 // 10)) + it.setstate(checkpoint) + items1b = list(itertools.islice(it, n * 7 // 10)) + self.assertListEqual(items1a, items1b) + + +class TestPrefetchIterator(unittest.TestCase, TestCheckpointableIterator): + def setUp(self): + self.expected_result = list(range(53)) + source_iterator = NativeCheckpointableIterator(self.expected_result) + self.iterator = PrefetchIterator(source_iterator, buffer_size=13) + + +class Test_chunked_dataset_iterator(TestBase): + def test_no_shuffle(self): + items = list( + itertools.islice( + chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=False, + buffer_size=1000, + ), + len(self.flattened_test_data), + ) + ) + self.assertListEqual(items, self.flattened_test_data) + + def test_other_files_present(self): + with open(os.path.join(self.data_dir, "i_do_not_belong_here.txt"), "w") as f: + f.write("really ...") + items = list( + itertools.islice( + chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=False, + buffer_size=1000, + ), + len(self.flattened_test_data), + ) + ) + self.assertListEqual(items, self.flattened_test_data) + + def test_transform(self): + transform = lambda s: s + "!" 
+ modified_test_data = [transform(s) for s in self.flattened_test_data] + items = list( + itertools.islice( + chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=False, + buffer_size=1000, + transform=transform, + ), + len(self.flattened_test_data), + ) + ) + self.assertListEqual(items, modified_test_data) + + def test_two_instances(self): + dataset0 = chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=False, + buffer_size=1000, + num_instances=2, + instance_rank=0, + ) + dataset1 = chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=False, + buffer_size=1000, + num_instances=2, + instance_rank=1, + ) + items0 = list( + itertools.islice(dataset0, len(self.test_data[0]) + len(self.test_data[2])) + ) + items1 = list( + itertools.islice(data set1, len(self.test_data[1]) + len(self.test_data[3])) + ) + self.assertMultisetEqual(set(items0 + items1), self.flattened_test_data) + + def test_checkpointing(self): + random = Random(1) + for use_windowed in (True, False): + for i in range(2): + first_length = random.randrange(11, 21) + extra_length = random.randrange(11, 21) + dataset = chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=(i % 2 == 0), + buffer_size=1000, + seed=i, + num_instances=2, + instance_rank=0, + use_windowed=use_windowed, + ) + for _ in range(first_length): + next(dataset) + checkpoint = dataset.getstate() + items1 = list(itertools.islice(dataset, extra_length)) + dataset.setstate(checkpoint) + items2 = list(itertools.islice(dataset, extra_length)) + self.assertListEqual(items1, items2) + + +class TestBucketedReadaheadBatchIterator(TestBase): + def txest_basic_functionality(self): + num_batches = 13 + batch_labels = ( + 75 # note: these settings imply a few iterations through the chunks + ) + # basic operation, should not crash + bg = BucketedReadaheadBatchIterator( + chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=True, + buffer_size=1000, + seed=1, + ), + read_ahead=100, + seed=1, + key=lambda line: len(line), + batch_size=lambda line: batch_labels // (1 + len(line)), + ) + batches1 = list(itertools.islice(bg, num_batches)) + # verify determinism + bg = BucketedReadaheadBatchIterator( + chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=True, + buffer_size=1000, + seed=1, + ), + read_ahead=100, + seed=1, + key=lambda line: len(line), + batch_size=lambda line: batch_labels // (1 + len(line)), + ) + batches2 = list(itertools.islice(bg, num_batches)) + print([(len(batch[0]), len(batch)) for batch in batches1]) + self.assertListEqual(batches1, batches2) + + def test_checkpointing(self): + first_batches = 12 + extra_batches = 7 + batch_labels = 123 + bg = BucketedReadaheadBatchIterator( + chunked_dataset_iterator( + self.chunk_file_paths, + self.read_chunk, + shuffle=True, + buffer_size=1000, + seed=1, + ), + read_ahead=100, + seed=1, + key=lambda line: len(line), + batch_size=lambda line: batch_labels // (1 + len(line)), + ) + _ = list(itertools.islice(bg, first_batches)) + checkpoint = bg.getstate() + batches1 = list(itertools.islice(bg, extra_batches)) + bg.setstate(checkpoint) + batches2 = list(itertools.islice(bg, extra_batches)) + self.assertListEqual(batches1, batches2) + + +if __name__ == "__main__": + unittest.main() diff --git a/model/third_party/HMNet/DataLoader/infinibatch/unit-test-pipeline.yml b/model/third_party/HMNet/DataLoader/infinibatch/unit-test-pipeline.yml new file mode 100644 index 
0000000000000000000000000000000000000000..533ce122afb42afcc514594c772ec164ddbce242 --- /dev/null +++ b/model/third_party/HMNet/DataLoader/infinibatch/unit-test-pipeline.yml @@ -0,0 +1,65 @@ +# Python package +# Create and test a Python package on multiple Python versions. +# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more: +# https://docs.m icrosoft.com/azure/devops/pipelines/languages/python + +trigger: +- master +- dev/* + +jobs: + - job: Linux + pool: + vmImage: 'ubuntu-latest' + strategy: + matrix: + Python35: + python.version: '3.5' + Python36: + python.version: '3.6' + Python37: + python.version: '3.7' + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '$(python.version)' + displayName: 'Use Python $(python.version)' + + - script: | + python -m pip install --upgrade pip + pip install -r requirements.txt + displayName: 'Install dependencies' + + - script: | + pip install unittest + python -m unittest discover -s ./test + displayName: 'unittest' + + - job: Windows + pool: + vmImage: 'windows-latest' + strategy: + matrix: + Python35: + python.version: '3.5' + Python36: + python.version: '3.6' + Python37: + python.version: '3.7' + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '$(python.version)' + displayName: 'Use Python $(python.version)' + + - script: | + python -m pip install --upgrade pip + pip install -r requirements.txt + displayName: 'Install dependencies' + + - script: | + pip install unittest + python -m unittest discover -s ./test + displayName: 'unittest' diff --git a/model/third_party/HMNet/Evaluation/OldROUGEEval.py b/model/third_party/HMNet/Evaluation/OldROUGEEval.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac8ddf6877e4d00b748f9bcd2c70ae3fbf21618 --- /dev/null +++ b/model/third_party/HMNet/Evaluation/OldROUGEEval.py @@ -0,0 +1,432 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ROUGe metric implementation. + +This is a modified and slightly extended verison of +https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import itertools +import numpy as np + +# pylint: disable=C0103 + + +def _get_ngrams(n, text): + """Calcualtes n-grams. 
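+
+    E.g. (illustrative) _get_ngrams(2, ["a", "b", "a", "b"]) returns {"a b": 2, "b a": 1},
+    i.e. a dict that maps each n-gram to its count.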
+ + Args: + n: which n-grams to calculate + text: An array of tokens + + Returns: + A set of n-grams + """ + ngram_set = {} + text_length = len(text) + max_index_ngram_start = text_length - n + for i in range(max_index_ngram_start + 1): + k = " ".join(text[i : i + n]) + if k not in ngram_set: + ngram_set[k] = 0 + ngram_set[k] += 1 + return ngram_set + + +def _get_su(dist, text): + """Calcualtes skip-grams and unigram + + Args: + n: which n-grams to calculate + text: An array of tokens + + Returns: + A set of n-grams + """ + su_set = {} + text_length = len(text) + for i in range(text_length): + k = text[i] + if k not in su_set: + su_set[k] = 0 + su_set[k] += 1 + for j in range(i + 1, text_length): + if j - i - 1 > dist: + break + k = text[i] + " " + text[j] + if k not in su_set: + su_set[k] = 0 + su_set[k] += 1 + return su_set + + +def _split_into_words(sentences): + """Splits multiple sentences into words and flattens the result""" + return list(itertools.chain(*[_.split(" ") for _ in sentences])) + + +def _get_word_ngrams(n, sentences): + """Calculates word n-grams for multiple sentences.""" + assert len(sentences) > 0 + assert n > 0 + + words = _split_into_words(sentences) + return _get_ngrams(n, words) + + +def _get_word_su(dist, sentences): + """Calculates word skip-dist-grams for multiple sentences.""" + assert len(sentences) > 0 + assert dist > 0 + + words = _split_into_words(sentences) + return _get_su(dist, words) + + +def _len_lcs(x, y): + """ + Returns the length of the Longest Common Subsequence between sequen ces x + and y. + Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence + + Args: + x: sequence of words + y: sequence of words + + Returns + integer: Length of LCS between x and y + """ + table = _lcs(x, y) + n, m = len(x), len(y) + return table[n, m] + + +def _lcs(x, y): + """ + Computes the length of the longest common subsequence (lcs) between two + strings. The implementation below uses a DP programming algorithm and runs + in O(nm) time where n = len(x) and m = len(y). + Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence + + Args: + x: collection of words + y: collection of words + + Returns: + Table of dictionary of coord and len lcs + """ + n, m = len(x), len(y) + table = dict() + for i in range(n + 1): + for j in range(m + 1): + if i == 0 or j == 0: + table[i, j] = 0 + elif x[i - 1] == y[j - 1]: + table[i, j] = table[i - 1, j - 1] + 1 + else: + table[i, j] = max(table[i - 1, j], table[i, j - 1]) + return table + + +def _recon_lcs(x, y): + """ + Returns the Longest Subsequence between x and y. + Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence + + Args: + x: sequence of words + y: sequence of words + + Returns: + sequence: LCS of x and y + """ + i, j = len(x), len(y) + table = _lcs(x, y) + + def _recon(i, j): + """private recon calculation""" + if i == 0 or j == 0: + return [] + elif x[i - 1] == y[j - 1]: + return _recon(i - 1, j - 1) + [(x[i - 1], i)] + elif table[i - 1, j] > table[i, j - 1]: + return _recon(i - 1, j) + else: + return _recon(i, j - 1) + + recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) + return recon_tuple + + +def rouge_su(evaluated_sentences, reference_sentences, dist=4): + """ + Computes ROUGE-SU_dist of two text collections of sentences. 
+ Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ + papers/rouge-working-note-v1.3.1.pdf + + Args: + evaluated_sentences: The sentences that have been picked by the summarizer + reference_sentences: The sentences from the referene set + n: maximum distance between two tokens. Defaults to 4. + + Returns: + A tuple (f1, precision, recall) for ROUGE-SU4 + + Raises: + ValueError: raises exception if a param has len <= 0 + """ + return rouge_n(evaluated_sentences, reference_sentences, dist=dist, su=True) + + +def rouge_n(evaluated_sentences, reference_sentences, n=2, dist=4, su=False): + """ + Computes ROUGE-N of two text collections of sentences. + Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ + papers/rouge-working-note-v1.3.1.pdf + + Args: + evaluated_sentences: The sentences that have been picked by the summarizer + reference_sentences: The sentences from the referene set + n: Size of ngram. Defaults to 2. + su: if true, we are computing rouge_su + + Returns: + A tuple (f1, precision, recall) for ROUGE-N + + Raises: + ValueError: raises exception if a param has len <= 0 + """ + if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: + raise ValueError("Collections must contain at least 1 sentence.") + + if su == True: + evaluated_ngrams = _get_word_su(dist, evaluated_sentences) + reference_ngrams = _get_word_su(dist, reference_sentences) + else: + evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) + reference_ngrams = _get_word_ngrams(n, reference_sentences) + + reference_count = sum([v for k, v in reference_ngrams.items()]) + evaluated_count = sum([v for k, v in evaluated_ngrams.items()]) + + # Gets the overlapping ngrams between evaluated and reference + overlapping_count = 0 + for k, v in reference_ngrams.items(): + if k in evaluated_ngrams: + if evaluated_ngrams[k] < v: + overlapping_count += evaluated_ngrams[k] + else: + overlapping_count += v + + # Handle edge case. This isn't mathematically correct, but it's good enough + if evaluated_count == 0: + precision = 0.0 + else: + precision = overlapping_count / evaluated_count + + if reference_count == 0: + recall = 0.0 + else: + recall = overlapping_count / reference_count + + f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) + + # return overlapping_count / reference_count + return f1_score, precision, recall + + +def _f_p_r_lcs(llcs, m, n): + """ + Computes the LCS-based F-measure score + Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ + rouge-working-note-v1.3.1.pdf + + Args: + llcs: Length of LCS + m: number of words in reference summary + n: number of words in candidate summary + + Returns: + Float. LCS-based F-measure score + """ + r_lcs = llcs / m + p_lcs = llcs / n + beta = p_lcs / (r_lcs + 1e-12) + num = (1 + (beta ** 2)) * r_lcs * p_lcs + denom = r_lcs + ((beta ** 2) * p_lcs) + f_lcs = num / (denom + 1e-12) + return f_lcs, p_lcs, r_lcs + + +def rouge_l_sentence_level(evaluated_sentences, reference_sentences): + """ + Computes ROUGE-L (sentence level) of two text collections of sentences. 
+    http://research.microsoft.com/en-us/um/people/cyl/download/papers/
+    rouge-working-note-v1.3.1.pdf
+
+    Calculated according to:
+    R_lcs = LCS(X,Y)/m
+    P_lcs = LCS(X,Y)/n
+    F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
+
+    where:
+    X = reference summary
+    Y = Candidate summary
+    m = length of reference summary
+    n = length of candidate summary
+
+    Args:
+        evaluated_sentences: The sentences that have been picked by the summarizer
+        reference_sentences: The sentences from the reference set
+
+    Returns:
+        A float: F_lcs
+
+    Raises:
+        ValueError: raises exception if a param has len <= 0
+    """
+    if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
+        raise ValueError("Collections must contain at least 1 sentence.")
+    reference_words = _split_into_words(reference_sentences)
+    evaluated_words = _split_into_words(evaluated_sentences)
+    m = len(reference_words)
+    n = len(evaluated_words)
+    lcs = _len_lcs(evaluated_words, reference_words)
+    return _f_p_r_lcs(lcs, m, n)
+
+
+def _union_lcs(evaluated_sentences, reference_sentence):
+    """
+    Returns LCS_u(r_i, C) which is the LCS score of the union longest common
+    subsequence between reference sentence r_i and candidate summary C. For example,
+    if r_i = w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and
+    c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is
+    “w1 w2” and the longest common subsequence of r_i and c2 is “w1 w3 w5”. The
+    union longest common subsequence of r_i, c1, and c2 is “w1 w2 w3 w5” and
+    LCS_u(r_i, C) = 4/5.
+
+    Args:
+        evaluated_sentences: The sentences that have been picked by the summarizer
+        reference_sentence: One of the sentences in the reference summaries
+
+    Returns:
+        float: LCS_u(r_i, C)
+
+    Raises:
+        ValueError: raises exception if a param has len <= 0
+    """
+    if len(evaluated_sentences) <= 0:
+        raise ValueError("Collections must contain at least 1 sentence.")
+
+    lcs_union = set()
+    reference_words = _split_into_words([reference_sentence])
+    combined_lcs_length = 0
+    for eval_s in evaluated_sentences:
+        evaluated_words = _split_into_words([eval_s])
+        lcs = set(_recon_lcs(reference_words, evaluated_words))
+        combined_lcs_length += len(lcs)
+        lcs_union = lcs_union.union(lcs)
+
+    union_lcs_count = len(lcs_union)
+    union_lcs_value = union_lcs_count / combined_lcs_length
+    return union_lcs_value
+
+
+def rouge_l_summary_level(evaluated_sentences, reference_sentences):
+    """
+    Computes ROUGE-L (summary level) of two text collections of sentences.
+ http://research.microsoft.com/en-us/um/people/cyl/download/papers/ + rouge-working-note-v1.3.1.pdf + + Calculated according to: + R_lcs = SUM(1, u)[LCS(r_i,C)]/m + P_lcs = SUM(1, u)[LCS(r_i,C)]/n + F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) + + where: + SUM(i,u) = SUM from i through u + u = number of sentences in reference summary + C = Candidate summary made up of v sentences + m = number of words in reference summary + n = number of words in candidate summary + + Args: + evaluated_sentences: The sentences that have been picked by the summarizer + reference_sentence: One of the sentences in the reference summaries + + Returns: + A float: F_lcs + + Raises: + ValueError: raises exception if a param has len <= 0 + """ + if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: + raise ValueError("Collections must contain at least 1 sentence.") + + # total number of words in reference sentences + m = len(_split_into_words(reference_sentences)) + + # total number of words in evaluated sentences + n = len(_split_into_words(evaluated_sentences)) + + union_lcs_sum_across_all_references = 0 + for ref_s in reference_sentences: + union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, ref_s) + return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) + + +def rouge(hypotheses, references): + """Calculates average rouge scores for a list of hypotheses and + references""" + + # Filter out hyps that are of 0 length + # hyps_and_refs = zip(hypotheses, references) + # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] + # hypotheses, references = zip(*hyps_and_refs) + + # Calculate ROUGE-1 F1, precision, recall scores + rouge_1 = [rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references)] + rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) + + # Calculate ROUGE-2 F1, precision, recall scores + rouge_2 = [rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references)] + rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) + + # Calculate ROUGE-SU4 F1, precision, recall scores + rouge_su4 = [rouge_su([hyp], [ref], 4) for hyp, ref in zip(hypotheses, references)] + rouge_su4_f, rouge_su4_p, rouge_su4_r = map(np.mean, zip(*rouge_su4)) + + # Calculate ROUGE-L F1, precision, recall scores + rouge_l = [ + rouge_l_sentence_level([hyp], [ref]) for hyp, ref in zip(hypotheses, references) + ] + rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) + + return { + "rouge_1_f_score": rouge_1_f, + "rouge_2_f_score": rouge_2_f, + "rouge_su4_f_score": rouge_su4_f, + "rouge_l_f_score": rouge_l_f, + } + + +class OldROUGEEval: + def __init__(self): + pass + + def make_html_safe(self, s): + s.replace("<", "<") + s.replace(">", ">") + return s + + def eval(self, predictions, groundtruths): + predictions = [self.make_html_safe(w) for w in predictions] + groundtruths = [self.make_html_safe(w) for w in groundtruths] + results = rouge(predictions, groundtruths) + return results diff --git a/model/third_party/HMNet/Evaluation/ROUGEEval.py b/model/third_party/HMNet/Evaluation/ROUGEEval.py new file mode 100644 index 0000000000000000000000000000000000000000..e5fb9a95319404cb2ed1d87711947599a1fb7a46 --- /dev/null +++ b/model/third_party/HMNet/Evaluation/ROUGEEval.py @@ -0,0 +1,354 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import os +import re +import shutil +from string import ascii_uppercase +from tqdm.auto import tqdm +from model.third_party.HMNet.Evaluation.OldROUGEEval import rouge +from model.third_party.HMNet.ThirdParty.ROUGE import pyrouge +from shutil import copyfile +from mpi4py import MPI +import torch +import logging +import json + + +def write_json_res( + output_file, tokenizers, x_ids, y_ids, x_tokens, y_tokens, predictions, gts +): + data = [] + + # for x_id, y_id, x_token, y_token, preds, gt in zip(x_ids, y_ids, x_tokens, y_tokens, predictions, gts): + # x_id = tokenizers[0].decode(x_id, skip_special_tokens=False) if x_id.dim() == 1 else tokenizers[0].convert_tokens_to_string(x_token) + # y_id = tokenizers[1].decode(y_id, skip_special_tokens=False) if y_id.dim() == 1 else tokenizers[1].convert_tokens_to_string(y_token) + for x_token, y_token, preds, gt in zip(x_tokens, y_tokens, predictions, gts): + data.append( + { + # 'x_ids': x_id, + # 'y_ids': y_id, + "x_tokens": x_token if isinstance(x_token, str) else " ".join(x_token), + "y_tokens": y_token if isinstance(y_token, str) else " ".join(y_token), + "predictions": preds, + "gt": gt, + } + ) + + json.dump(data, output_file, indent=4, ensure_ascii=False) + + +logger = logging.getLogger(__name__) + +""" +This code can only be run within docker "rouge", because of the usage of rouge-perl +""" + + +"""" In ROUGE parlance, your summaries are ‘system’ summaries and the gold standard summaries are ‘model’ summaries. +The summaries should be in separate folders, whose paths are set with the system_dir and model_dir variables. +All summaries should contain one sentence per line.""" + + +class ROUGEEval: + """ + Wrapper class for pyrouge. + Compute ROUGE given predictions and references for summarization evaluation. + """ + + def __init__(self, run_dir, save_dir, opt): + self.run_dir = run_dir + self.save_dir = save_dir + self.opt = opt + + # use relative path to make it work on Philly + self.pyrouge_dir = os.path.join( + os.path.dirname(__file__), "../ThirdParty/ROUGE/ROUGE-1.5.5/" + ) + + self.eval_batches_num = self.opt.get("EVAL_BATCHES_NUM", float("Inf")) + self.best_score = -float("Inf") + self.best_res = {} + + def reset_best_score(self, set_high=False): + if set_high: + self.best_score = float("Inf") + else: + self.best_score = -float("Inf") + + def make_html_safe(self, s): + s = s.replace("<", "<") + s = s.replace(">", ">") + return s + + def print_to_rouge_dir( + self, summaries, dir, suffix, split_chars, special_char_dict=None + ): + for idx, summary in enumerate(summaries): + fname = os.path.join(dir, "%06d_%s.txt" % (idx, suffix)) + with open(fname, "wb") as f: + sents = re.split(r"(?') + # else: + # new_predicitons.append(pred) + # return new_predicitons, new_groundtruths + + def _convert_tokens_to_string(self, tokenizer, tokens): + if "EVAL_TOKENIZED" in self.opt: + tokens = [t for t in tokens if t not in tokenizer.all_special_tokens] + if "EVAL_LOWERCASE" in self.opt: + tokens = [t.lower() for t in tokens] + if "EVAL_TOKENIZED" in self.opt: + return " ".join(tokens) + else: + return tokenizer.decode( + tokenizer.convert_tokens_to_ids(tokens), skip_special_tokens=True + ) + + def eval_batches(self, module, dev_batches, save_folder, label=""): + max_sent_len = int(self.opt["MAX_GEN_LENGTH"]) + + logger.info( + "Decoding current model ... 
\nSaving folder is {}".format(save_folder)
+        )
+
+        predictions = []  # prediction of tokens from model
+        x_tokens = []  # input tokens
+        y_tokens = []  # groundtruths tokens
+        x_ids = []  # input token ids
+        y_ids = []  # groundtruths token ids
+        gts = []  # groundtruths string
+        got_better_score = False
+        # err = 0
+        if not isinstance(module.tokenizer, list):
+            encoder_tokenizer = module.tokenizer
+            decoder_tokenizer = module.tokenizer
+        elif len(module.tokenizer) == 1:
+            encoder_tokenizer = module.tokenizer[0]
+            decoder_tokenizer = module.tokenizer[0]
+        elif len(module.tokenizer) == 2:
+            encoder_tokenizer = module.tokenizer[0]
+            decoder_tokenizer = module.tokenizer[1]
+        else:
+            assert False, "len(module.tokenizer) > 2"
+
+        with torch.no_grad():
+            for j, dev_batch in enumerate(dev_batches):
+                for b in dev_batch:
+                    if torch.is_tensor(dev_batch[b]):
+                        dev_batch[b] = dev_batch[b].to(self.opt["device"])
+
+                beam_search_res = module(
+                    dev_batch, beam_search=True, max_sent_len=max_sent_len
+                )
+                pred = [
+                    [t[0] for t in x] if len(x) > 0 else [[]] for x in beam_search_res
+                ]
+                predictions.extend(
+                    [
+                        [
+                            self._convert_tokens_to_string(decoder_tokenizer, tt)
+                            for tt in t
+                        ]
+                        for t in pred
+                    ]
+                )
+
+                gts.extend(
+                    [
+                        self._convert_tokens_to_string(decoder_tokenizer, t)
+                        for t in dev_batch["decoder_tokens"]
+                    ]
+                )
+                x_tokens.extend(dev_batch["encoder_tokens"])
+                y_tokens.extend(dev_batch["decoder_tokens"])
+
+                if ("DEBUG" in self.opt and j >= 10) or j >= self.eval_batches_num:
+                    # in debug mode, decode only the first 10 batches; otherwise decode the first self.eval_batches_num batches
+                    break
+
+        # use MPI to gather results from all processes / GPUs
+        # the result of the gather operation is a list of sublists
+        # each sublist corresponds to the list created on one of the MPI processes (or GPUs, respectively)
+        # we flatten this list into a "simple" list
+        assert len(predictions) == len(
+            gts
+        ), "len(predictions): {0}, len(gts): {1}".format(len(predictions), len(gts))
+        comm = MPI.COMM_WORLD
+        predictions = comm.gather(predictions, root=0)
+        x_tokens = comm.gather(x_tokens, root=0)
+        y_tokens = comm.gather(y_tokens, root=0)
+        # if GPU numbers are high (>=8), passing x_ids, y_ids to rank 0 will cause out of memory
+        # x_ids = comm.gather(x_ids, root=0)
+        # y_ids = comm.gather(y_ids, root=0)
+        gts = comm.gather(gts, root=0)
+        if self.opt["rank"] == 0:
+            # flatten lists
+            predictions = [item for sublist in predictions for item in sublist]
+            y_tokens = [item for sublist in y_tokens for item in sublist]
+            x_tokens = [item for sublist in x_tokens for item in sublist]
+            # x_ids = [item for sublist in x_ids for item in sublist]
+            # y_ids = [item for sublist in y_ids for item in sublist]
+            gts = [item for sublist in gts for item in sublist]
+            # import pdb; pdb.set_trace()
+            assert (
+                len(predictions) == len(y_tokens) == len(x_tokens) == len(gts)
+            ), "len(predictions): {0}, len(y_tokens): {1}, len(x_tokens): {2}, len(gts): {3}".format(
+                len(predictions), len(y_tokens), len(x_tokens), len(gts)
+            )
+
+            # write intermediate results only on rank 0
+            if not os.path.isdir(os.path.join(save_folder, "intermediate_results")):
+                os.makedirs(os.path.join(save_folder, "intermediate_results"))
+            top_1_predictions = [pred[0] for pred in predictions]
+            with open(
+                os.path.join(
+                    save_folder, "intermediate_results", "res_" + label + ".json"
+                ),
+                "w",
+                encoding="utf-8",
+            ) as output_file:
+                write_json_res(
+                    output_file,
+                    [encoder_tokenizer, decoder_tokenizer],
+                    x_ids,
+                    y_ids,
+                    x_tokens,
+                    y_tokens,
predictions, + gts, + ) + try: + result = self.eval(top_1_predictions, gts) + except Exception as e: + logger.exception("ROUGE Eval ERROR") + result = {} + score = -float("Inf") + pass # this happens when there is no overlap between pred and gts + else: + rouge_su4 = rouge(top_1_predictions, gts) # f, prec, recall + result = { + "ROUGE_1": result["rouge_1_f_score"] * 100.0, + "ROUGE_1_Prc": result["rouge_1_precision"] * 100.0, + "ROUGE_1_Rcl": result["rouge_1_recall"] * 100.0, + "ROUGE_2": result["rouge_2_f_score"] * 100.0, + "ROUGE_2_Prc": result["rouge_2_precision"] * 100.0, + "ROUGE_2_Rcl": result["rouge_2_recall"] * 100.0, + "ROUGE_L": result["rouge_l_f_score"] * 100.0, + "ROUGE_L_Prc": result["rouge_l_precision"] * 100.0, + "ROUGE_L_Rcl": result["rouge_l_recall"] * 100.0, + "ROUGE_SU4": rouge_su4["rouge_su4_f_score"] * 100.0, + } + + score = result["ROUGE_1"] + if score > self.best_score: + copyfile( + os.path.join( + save_folder, + "intermediate_results", + "res_" + label + ".json", + ), + os.path.join( + save_folder, + "intermediate_results", + "res_" + label + ".best.json", + ), + ) + self.best_score = score + self.best_res = result + got_better_score = True + + else: + result = {} + score = -float("Inf") + got_better_score = False + + return result, score, got_better_score + + def eval(self, predictions, groundtruths): + # predictions, groundtruths = self.filter_empty(predictions, groundtruths) + predictions = [self.make_html_safe(w) for w in predictions] + groundtruths = [self.make_html_safe(w) for w in groundtruths] + pred_dir = os.path.join(self.save_dir, "predictions") + if os.path.exists(pred_dir): + shutil.rmtree(pred_dir) + os.makedirs(pred_dir) + + gt_dir = os.path.join(self.save_dir, "groundtruths") + if os.path.exists(gt_dir): + shutil.rmtree(gt_dir) + os.makedirs(gt_dir) + + special_char_dict = self.print_to_rouge_dir_gt( + groundtruths, gt_dir, "gt", "SPLIT_CHARS_FOR_EVAL" in self.opt + ) + self.print_to_rouge_dir( + predictions, + pred_dir, + "pred", + "SPLIT_CHARS_FOR_EVAL" in self.opt, + special_char_dict, + ) + + r = pyrouge.Rouge155(self.pyrouge_dir) + r.system_dir = pred_dir + r.model_dir = gt_dir + r.system_filename_pattern = "(\d+)_pred.txt" + r.model_filename_pattern = "[A-Z].#ID#_gt.txt" + results = r.output_to_dict(r.convert_and_evaluate()) + return results diff --git a/model/third_party/HMNet/Evaluation/__init__.py b/model/third_party/HMNet/Evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/third_party/HMNet/ExampleConf/conf_eval_hmnet_AMI b/model/third_party/HMNet/ExampleConf/conf_eval_hmnet_AMI new file mode 100644 index 0000000000000000000000000000000000000000..30266c4bf7ae3fc94de7fe034aeb5f5af972dbc9 --- /dev/null +++ b/model/third_party/HMNet/ExampleConf/conf_eval_hmnet_AMI @@ -0,0 +1,98 @@ +################## +# Trainer settings +################## + +MODEL MeetingNet_Transformer +TASK HMNet +CRITERION MLECriterion + +SEED 1033 + +MAX_NUM_EPOCHS 20 +EVAL_PER_UPDATE_NUM 10 +UPDATES_PER_EPOCH 20 + +# The actual learning rate will be multiplied by the number of GPUs +OPTIMIZER RAdam +START_LEARNING_RATE 1e-3 +LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler +WARMUP_STEPS 16000 +WARMUP_INIT_LR 1e-4 +WARMUP_END_LR 1e-3 + +# The actual start learning rate equals START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP +# Model will be updated after every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples +GRADIENT_ACCUMULATE_STEP 5 + +GRAD_CLIPPING 2 + +################## +# Task 
settings +################## + +# This is the relative path to the directory where this conf file is located +# not a good idea to put data with code +# Are we able to provide a list of dir paths in TRAIN_FILE? +USE_REL_DATA_PATH +TRAIN_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json +DEV_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json +TEST_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json +ROLE_DICT_FILE ../ExampleRawData/meeting_summarization/role_dict_ext.json + +MINI_BATCH 1 +MAX_PADDING_RATIO 1 +BATCH_READ_AHEAD 10 +DOC_SHUFFLE_BUF_SIZE 10 +SAMPLE_SHUFFLE_BUFFER_SIZE 10 +BATCH_SHUFFLE_BUFFER_SIZE 10 + +MAX_TRANSCRIPT_WORD 8300 +MAX_SENT_LEN 30 +MAX_SENT_NUM 300 + +################## +# Model settings +################## + +DROPOUT 0.1 +VOCAB_DIM 512 +ROLE_SIZE 32 +ROLE_DIM 16 +POS_DIM 16 +ENT_DIM 16 + +USE_ROLE +USE_POSENT + +USE_BOS_TOKEN +USE_EOS_TOKEN + +TRANSFORMER_EMBED_DROPOUT 0.1 +TRANSFORMER_RESIDUAL_DROPOUT 0.1 +TRANSFORMER_ATTENTION_DROPOUT 0.1 +TRANSFORMER_LAYER 6 +TRANSFORMER_HEAD 8 +TRANSFORMER_POS_DISCOUNT 80 + +PRE_TOKENIZER TransfoXLTokenizer +PRE_TOKENIZER_PATH ../ExampleInitModel/transfo-xl-wt103 +PYLEARN_MODEL +# e.g. PYLEARN_MODEL conf_hmnet_AMI_conf~/run_1/11600 +# PYLEARN_MODEL ../ExampleInitModel/AMI-finetuned + +################## +# Tokenizer settings +################## + +EXTRA_IDS 1000 + +################## +# Decoding settings +################## + +BEAM_WIDTH 6 +EVAL_TOKENIZED +EVAL_LOWERCASE +MAX_GEN_LENGTH 512 +MIN_GEN_LENGTH 400 +NO_REPEAT_NGRAM_SIZE 3 \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleConf/conf_eval_hmnet_ICSI b/model/third_party/HMNet/ExampleConf/conf_eval_hmnet_ICSI new file mode 100644 index 0000000000000000000000000000000000000000..18d671da7b9728a9c915b16f4b8c81cb95aa3a70 --- /dev/null +++ b/model/third_party/HMNet/ExampleConf/conf_eval_hmnet_ICSI @@ -0,0 +1,98 @@ +################## +# Trainer settings +################## + +MODEL MeetingNet_Transformer +TASK HMNet +CRITERION MLECriterion + +SEED 1033 + +MAX_NUM_EPOCHS 20 +EVAL_PER_UPDATE_NUM 10 +UPDATES_PER_EPOCH 20 + +# The actual learning rate will be multiplied by the number of GPUs +OPTIMIZER RAdam +START_LEARNING_RATE 1e-3 +LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler +WARMUP_STEPS 16000 +WARMUP_INIT_LR 1e-4 +WARMUP_END_LR 1e-3 + +# The actual start learning rate equals START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP +# Model will be updated after every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples +GRADIENT_ACCUMULATE_STEP 5 + +GRAD_CLIPPING 2 + +################## +# Task settings +################## + +# This is the relative path to the directory where this conf file is located +# not a good idea to put data with code +# Are we able to provide a list of dir paths in TRAIN_FILE? 
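For reference, the ROUGEEval wrapper above reads every one of these settings through a flat `opt` dictionary (e.g. `self.opt["MAX_GEN_LENGTH"]`, `"EVAL_TOKENIZED" in self.opt`). A minimal sketch of how a conf file in this format could be turned into such a dictionary; the `load_conf` helper below is hypothetical, not part of HMNet, and only illustrates the apparent key/value-plus-flag convention:

def load_conf(path):
    # "KEY VALUE" lines become opt["KEY"] = "VALUE"; bare keys such as USE_ROLE or
    # EVAL_TOKENIZED become flags whose mere presence is checked; '#' starts a comment.
    opt = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            key, _, value = line.partition(" ")
            opt[key] = value.strip()
    return opt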
+USE_REL_DATA_PATH +TRAIN_FILE ../ExampleRawData/meeting_summarization/ICSI_proprec/train_icsi.json +DEV_FILE ../ExampleRawData/meeting_summarization/ICSI_proprec/valid_icsi.json +TEST_FILE ../ExampleRawData/meeting_summarization/ICSI_proprec/test_icsi.json +ROLE_DICT_FILE ../ExampleRawData/meeting_summarization/role_dict_ext.json + +MINI_BATCH 1 +MAX_PADDING_RATIO 1 +BATCH_READ_AHEAD 10 +DOC_SHUFFLE_BUF_SIZE 10 +SAMPLE_SHUFFLE_BUFFER_SIZE 10 +BATCH_SHUFFLE_BUFFER_SIZE 10 + +MAX_TRANSCRIPT_WORD 8300 +MAX_SENT_LEN 30 +MAX_SENT_NUM 300 + +################## +# Model settings +################## + +DROPOUT 0.1 +VOCAB_DIM 512 +ROLE_SIZE 32 +ROLE_DIM 16 +POS_DIM 16 +ENT_DIM 16 + +USE_ROLE +USE_POSENT + +USE_BOS_TOKEN +USE_EOS_TOKEN + +TRANSFORMER_EMBED_DROPOUT 0.1 +TRANSFORMER_RESIDUAL_DROPOUT 0.1 +TRANSFORMER_ATTENTION_DROPOUT 0.1 +TRANSFORMER_LAYER 6 +TRANSFORMER_HEAD 8 +TRANSFORMER_POS_DISCOUNT 80 + +PRE_TOKENIZER TransfoXLTokenizer +PRE_TOKENIZER_PATH ../ExampleInitModel/transfo-xl-wt103 +PYLEARN_MODEL +# e.g. PYLEARN_MODEL conf_hmnet_ICSI_conf~/run_1/26800 +# PYLEARN_MODEL ../ExampleInitModel/ICSI-finetuned + +################## +# Tokenizer settings +################## + +EXTRA_IDS 1000 + +################## +# Decoding settings +################## + +BEAM_WIDTH 6 +EVAL_TOKENIZED +EVAL_LOWERCASE +MAX_GEN_LENGTH 512 +MIN_GEN_LENGTH 280 +NO_REPEAT_NGRAM_SIZE 3 \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleConf/conf_hmnet_AMI b/model/third_party/HMNet/ExampleConf/conf_hmnet_AMI new file mode 100644 index 0000000000000000000000000000000000000000..d5220f8d05db478224bfdb481b6742d5b1ad79d5 --- /dev/null +++ b/model/third_party/HMNet/ExampleConf/conf_hmnet_AMI @@ -0,0 +1,98 @@ +################## +# Trainer settings +################## + +MODEL MeetingNet_Transformer +TASK HMNet +CRITERION MLECriterion + +SEED 1033 +RESUME + +MAX_NUM_EPOCHS 20 +SAVE_PER_UPDATE_NUM 400 +UPDATES_PER_EPOCH 2000 + +# The actual learning rate will be multiplied by the number of GPUs +OPTIMIZER RAdam +NO_AUTO_LR_SCALING +START_LEARNING_RATE 1e-3 +LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler +WARMUP_STEPS 16000 +WARMUP_INIT_LR 1e-4 +WARMUP_END_LR 1e-3 + +# The actual start learning rate equals START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP +# Model will be updated after every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples +GRADIENT_ACCUMULATE_STEP 20 + +GRAD_CLIPPING 2 + +################## +# Task settings +################## + +# This is the relative path to the directory where this conf file is located +# not a good idea to put data with code +# Are we able to provide a list of dir paths in TRAIN_FILE? 
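The comments in the trainer settings above spell out two relationships: the optimizer steps once every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples, and the effective start learning rate is START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP (further scaled by the GPU count unless NO_AUTO_LR_SCALING is set). A small worked example with the conf_hmnet_AMI values, as a sketch of the stated rules rather than HMNet's trainer code:

MINI_BATCH = 1
GRADIENT_ACCUMULATE_STEP = 20
START_LEARNING_RATE = 1e-3
samples_per_update = MINI_BATCH * GRADIENT_ACCUMULATE_STEP            # 20 samples per optimizer step
effective_start_lr = START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP   # 0.02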
+USE_REL_DATA_PATH +TRAIN_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json +DEV_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json +TEST_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json +ROLE_DICT_FILE ../ExampleRawData/meeting_summarization/role_dict_ext.json + +MINI_BATCH 1 +MAX_PADDING_RATIO 1 +BATCH_READ_AHEAD 10 +DOC_SHUFFLE_BUF_SIZE 10 +SAMPLE_SHUFFLE_BUFFER_SIZE 10 +BATCH_SHUFFLE_BUFFER_SIZE 10 + +MAX_TRANSCRIPT_WORD 8300 +MAX_SENT_LEN 30 +MAX_SENT_NUM 300 + +################## +# Model settings +################## + +DROPOUT 0.1 +VOCAB_DIM 512 +ROLE_SIZE 32 +ROLE_DIM 16 +POS_DIM 16 +ENT_DIM 16 + +USE_ROLE +USE_POSENT + +USE_BOS_TOKEN +USE_EOS_TOKEN + +TRANSFORMER_EMBED_DROPOUT 0.1 +TRANSFORMER_RESIDUAL_DROPOUT 0.1 +TRANSFORMER_ATTENTION_DROPOUT 0.1 +TRANSFORMER_LAYER 6 +TRANSFORMER_HEAD 8 +TRANSFORMER_POS_DISCOUNT 80 + +PRE_TOKENIZER TransfoXLTokenizer +PRE_TOKENIZER_PATH ../ExampleInitModel/transfo-xl-wt103 +PYLEARN_MODEL ../ExampleInitModel/HMNet-pretrained + +################## +# Tokenizer settings +################## + +EXTRA_IDS 1000 + +################## +# Decoding settings +################## + +BEAM_WIDTH 6 +MAX_GEN_LENGTH 512 +MIN_GEN_LENGTH 320 +EVAL_TOKENIZED +EVAL_LOWERCASE +NO_REPEAT_NGRAM_SIZE 3 \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleConf/conf_hmnet_ICSI b/model/third_party/HMNet/ExampleConf/conf_hmnet_ICSI new file mode 100644 index 0000000000000000000000000000000000000000..e3c46e5bb56f32fdbd8e2b3412ee1a371b446c6a --- /dev/null +++ b/model/third_party/HMNet/ExampleConf/conf_hmnet_ICSI @@ -0,0 +1,98 @@ +################## +# Trainer settings +################## + +MODEL MeetingNet_Transformer +TASK HMNet +CRITERION MLECriterion + +SEED 1033 +RESUME + +MAX_NUM_EPOCHS 20 +SAVE_PER_UPDATE_NUM 400 +UPDATES_PER_EPOCH 2000 + +# The actual learning rate will be multiplied by the number of GPUs +OPTIMIZER RAdam +NO_AUTO_LR_SCALING +START_LEARNING_RATE 1e-3 +LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler +WARMUP_STEPS 16000 +WARMUP_INIT_LR 1e-4 +WARMUP_END_LR 1e-3 + +# The actual start learning rate equals START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP +# Model will be updated after every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples +GRADIENT_ACCUMULATE_STEP 20 + +GRAD_CLIPPING 2 + +################## +# Task settings +################## + +# This is the relative path to the directory where this conf file is located +# not a good idea to put data with code +# Are we able to provide a list of dir paths in TRAIN_FILE? 
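Both training confs pair WARMUP_STEPS 16000, WARMUP_INIT_LR 1e-4 and WARMUP_END_LR 1e-3 with LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler. The name suggests a linear warmup followed by inverse-square-root decay; a hedged sketch of that common schedule follows (the exact HMNet implementation may differ):

import math

def lr_at_step(step, warmup_steps=16000, warmup_init_lr=1e-4, warmup_end_lr=1e-3):
    # Linear ramp from warmup_init_lr to warmup_end_lr over warmup_steps,
    # then decay proportional to 1/sqrt(step), continuous at the boundary.
    if step < warmup_steps:
        return warmup_init_lr + (warmup_end_lr - warmup_init_lr) * step / warmup_steps
    return warmup_end_lr * math.sqrt(warmup_steps / step)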
+USE_REL_DATA_PATH +TRAIN_FILE ../ExampleRawData/meeting_summarization/ICSI_proprec/train_icsi.json +DEV_FILE ../ExampleRawData/meeting_summarization/ICSI_proprec/valid_icsi.json +TEST_FILE ../ExampleRawData/meeting_summarization/ICSI_proprec/test_icsi.json +ROLE_DICT_FILE ../ExampleRawData/meeting_summarization/role_dict_ext.json + +MINI_BATCH 1 +MAX_PADDING_RATIO 1 +BATCH_READ_AHEAD 10 +DOC_SHUFFLE_BUF_SIZE 10 +SAMPLE_SHUFFLE_BUFFER_SIZE 10 +BATCH_SHUFFLE_BUFFER_SIZE 10 + +MAX_TRANSCRIPT_WORD 8300 +MAX_SENT_LEN 30 +MAX_SENT_NUM 300 + +################## +# Model settings +################## + +DROPOUT 0.1 +VOCAB_DIM 512 +ROLE_SIZE 32 +ROLE_DIM 16 +POS_DIM 16 +ENT_DIM 16 + +USE_ROLE +USE_POSENT + +USE_BOS_TOKEN +USE_EOS_TOKEN + +TRANSFORMER_EMBED_DROPOUT 0.1 +TRANSFORMER_RESIDUAL_DROPOUT 0.1 +TRANSFORMER_ATTENTION_DROPOUT 0.1 +TRANSFORMER_LAYER 6 +TRANSFORMER_HEAD 8 +TRANSFORMER_POS_DISCOUNT 80 + +PRE_TOKENIZER TransfoXLTokenizer +PRE_TOKENIZER_PATH ../ExampleInitModel/transfo-xl-wt103 +PYLEARN_MODEL ../ExampleInitModel/HMNet-pretrained + +################## +# Tokenizer settings +################## + +EXTRA_IDS 1000 + +################## +# Decoding settings +################## + +BEAM_WIDTH 6 +MAX_GEN_LENGTH 512 +MIN_GEN_LENGTH 420 +EVAL_TOKENIZED +EVAL_LOWERCASE +NO_REPEAT_NGRAM_SIZE 3 \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleInitModel/AMI-finetuned/README.md b/model/third_party/HMNet/ExampleInitModel/AMI-finetuned/README.md new file mode 100644 index 0000000000000000000000000000000000000000..05cb211af02125ee21b1a677bffd393f5b5f5a1a --- /dev/null +++ b/model/third_party/HMNet/ExampleInitModel/AMI-finetuned/README.md @@ -0,0 +1,3 @@ +# Download the HMNet model finetuned for AMI dataset + +Using the download [link](https://sdrgstorage01wus2.blob.core.windows.net/user/ruox/Meeting_Minutes/HMNet/ExampleInitModel/AMI-finetuned/model.pt?sv=2019-10-10&st=2020-10-22T19%3A25%3A46Z&se=2060-10-23T19%3A25%3A00Z&sr=b&sp=r&sig=VTzk30aQu5KKSgKdW2L9DUYGQyZmns16WnIm%2FifMKZQ%3D) to download the `model.pt` file and put it in this directory. \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleInitModel/HMNet-pretrained/README.md b/model/third_party/HMNet/ExampleInitModel/HMNet-pretrained/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1a9e9d8ebcac1b537a6bd4afc7b01835437e66f2 --- /dev/null +++ b/model/third_party/HMNet/ExampleInitModel/HMNet-pretrained/README.md @@ -0,0 +1,3 @@ +# Download the pretrained HMNet model + +Using the download [link](https://sdrgstorage01wus2.blob.core.windows.net/user/ruox/Meeting_Minutes/HMNet/ExampleInitModel/HMNet-pretrained/model.pt?sv=2019-10-10&st=2020-10-22T19%3A24%3A06Z&se=2060-10-23T19%3A24%3A00Z&sr=b&sp=r&sig=cRfastEaN7s75cgMaBvEFGbXio20smnjjRxxYbqEkoE%3D) to download the `model.pt` file and put it in this directory. 
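These ExampleInitModel READMEs all follow the same pattern: fetch `model.pt` from the given link and place it in the corresponding directory. A sketch of doing that programmatically; `model_url` stands in for the download link quoted in the README, and `fetch_checkpoint` is a hypothetical helper, not part of this repository:

import os
import urllib.request

def fetch_checkpoint(model_url, dest_dir):
    # e.g. dest_dir = "model/third_party/HMNet/ExampleInitModel/HMNet-pretrained"
    os.makedirs(dest_dir, exist_ok=True)
    urllib.request.urlretrieve(model_url, os.path.join(dest_dir, "model.pt"))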
\ No newline at end of file diff --git a/model/third_party/HMNet/ExampleInitModel/ICSI-finetuned/README.md b/model/third_party/HMNet/ExampleInitModel/ICSI-finetuned/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4710b19942ff4f86f79321a9f744a3516aa1b382 --- /dev/null +++ b/model/third_party/HMNet/ExampleInitModel/ICSI-finetuned/README.md @@ -0,0 +1,3 @@ +# Download the HMNet model finetuned for ICSI dataset + +Using the download [link](https://sdrgstorage01wus2.blob.core.windows.net/user/ruox/Meeting_Minutes/HMNet/ExampleInitModel/ICSI-finetuned/model.pt?sv=2019-10-10&st=2020-10-24T00%3A10%3A47Z&se=2060-10-25T00%3A10%3A00Z&sr=b&sp=r&sig=9vYc0%2BRRRiWwleywDFGOHqBIzzdQbZ4OnVqeZKsRzyM%3D) to download the `model.pt` file and put it in this directory. \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/special_tokens_map.json b/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..2422483358bd7e5be1ca6e165279403b08e13c78 --- /dev/null +++ b/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/special_tokens_map.json @@ -0,0 +1 @@ +{"eos_token": "", "unk_token": "", "additional_special_tokens": [""]} \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/tokenizer_config.json b/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..9e26dfeeb6e641a33dae4961196235bdb965b21b --- /dev/null +++ b/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/tokenizer_config.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/vocab.bin b/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/vocab.bin new file mode 100644 index 00000000000000000000000000000000000000 00..65920c897ff38919d3af5cc7780f70cbdf63650d Binary files /dev/null and b/model/third_party/HMNet/ExampleInitModel/transfo-xl-wt103/vocab.bin differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_0.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_0.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..3698c271ae451154356898ee1a1665bd34101368 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_0.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_1.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_1.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..4321513f21f5663266c246247a5f0a7d62230f1c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_1.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_10.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_10.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..ce056cc213431c1ec0dcb32891450a88857abbe3 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_10.jsonl.gz differ diff --git 
a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_11.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_11.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..703ea1735fd013fb8122ea6ebac0cd87c1dc50ee Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_11.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_12.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_12.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..cbf1d42e7c48ea8036b9da3d10ab27167eef5f6b Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_12.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_13.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_13.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..b3380346c3eea221f639b5fab6b131b47d0b2d44 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_13.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_14.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_14.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..62f32d4285619003c57d4f61c07cf09004279eb4 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_14.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_15.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_15.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..491873cff7780f303e1a9989aa8589c4a03c6543 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_15.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_16.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_16.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..bc54af00b9acb16f9d7ec92477bb900663308c85 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/me eting_summarization/AMI_proprec/dev/split_16.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_17.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_17.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..7b9654ec8e5ccfd7ed04a8a80901458339b380f6 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_17.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_18.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_18.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..a05a48f26078620f593a5e070cf0d0bd8e873ac6 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_18.jsonl.gz 
differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_19.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_19.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..f5b576920107e9c0a950a21005a86081efd4cca2 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_19.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_2.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_2.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..444cf8138a87a7349955bee150326fff975278d8 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_2.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_3.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_3.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..2d3a5a74808a48b55d4b5dcf1c6d03e31707be92 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_3.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_4.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_4.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..4317cdafa8f8f49abb650e0a12e5bb437ed94b04 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_4.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_5.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_5.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..5613b5790e3f0fd2a9299a34a6f55d94e13c17fd Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_5.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_6.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_6.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..90527416b7c9d8f97c81de7be78713b28eadc8cc Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_6.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_7.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_7.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..67db2159295639c1f20441369376afee99104680 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_7.jsonl.gz differ diff --git a/model/third_party/HMNet/Ex ampleRawData/meeting_summarization/AMI_proprec/dev/split_8.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_8.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..ed7d69511510fed625a3ad429a8945b2742524b6 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_8.jsonl.gz differ 
diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_9.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_9.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..fb221ed0ade09ce477af0046d069dc3544e1376f Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/dev/split_9.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_0.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_0.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..8ed2a582ce10f4b064ad6f2f4d8494a73f3081e2 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_0.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_1.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_1.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..77adb51a09ec682fbd2d86e3fecc2079555cc10c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_1.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_10.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_10.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..6a7a8ade0955437ba0e3a0730d154d82224c613c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_10.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_11.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_11.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..5fdc1eb6766bb717bd19e1f0e65044182eb4ccab Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_11.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_12.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_12.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..9cf82e2e0eafc62506b391ff5edbe2c4ea81a089 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_12.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_13.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_13.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..39a34f1a4e5ba746599970f5195c802611b73457 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_13.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_14.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_14.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..be1db9359ee204cb0c11af229dadd4b3cb839dd6 Binary files /dev/null and 
b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_14.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_15.jsonl.gz b/m odel/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_15.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..1d638765c7fce8ca6c86cb936f0418bd65f550f7 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_15.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_16.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_16.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..65eb693d3f33723d5a004d95bc2d80c84533ac63 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_16.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_17.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_17.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..d37ecacaa9ad83fc3a04a2ec04b83ee116f95ee9 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_17.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_18.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_18.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..a7e97cd28d2864249054146ff0c6d29f91d4991d Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_18.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_19.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_19.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..4ab026f0a269d3b452e2bb262d8fbdab435c7954 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_19.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_2.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_2.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..1e394a66254c32d1bfe4f61ad8b5dc314b76bac8 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_2.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_3.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_3.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..64b10d5ddaa52cadccaa4678d1a878b9c28fa71e Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_3.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_4.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_4.jsonl.gz new file mode 100644 index 
0000000000000000000000000000000000000000..48597968dd42895369156af0aa53b36dd7fcd4ff Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_4.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_5.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_5.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..79ffed9951614260327117b78fc07e6f07f2e4a7 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_5.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_6.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_prop rec/test/split_6.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..63d32b86d7e0bebfa0b6aa447d4a87fc78d58eee Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_6.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_7.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_7.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..f2c288a32f713fba1f8871aa0620f785d58d510a Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_7.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_8.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_8.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..536223f1d5b85ec235e741cc2af97727a3d20b79 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_8.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_9.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_9.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..e08e2067c9b0218a7cfe313ad3afced99318c3a4 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test/split_9.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json new file mode 100644 index 0000000000000000000000000000000000000000..76a8fecc54e8aa12d4de808d5278cb0b53e4f0d5 --- /dev/null +++ b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json @@ -0,0 +1,10 @@ +[ + { + "source": + { + "dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/" + }, + "task": "meeting", + "name": "ami" + } +] diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_0.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_0.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..8177850d6ddd4378f6c4c954c9c71c3c4fca96d5 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_0.jsonl.gz differ diff --git 
a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_1.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_1.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..34c65fe89c170f68ba81db58f6e0dba504407a7d Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_1.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_10.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_10.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..706475b5bb31320a158a80dce87a57e8b1b2362a Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_10.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_11.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_11.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..d32ef99416ca25df2aacf4d095d62286151a0bc4 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_11.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summariza tion/AMI_proprec/train/split_12.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_12.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..3f417c4a880d11a89fb8b66013f6a8bf90a55be3 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_12.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_13.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_13.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..30f554db05328a0a3c189a1983690fb1bebc842f Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_13.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_14.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_14.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..6ea791ea5398a48aa38af4997e97de60235220ff Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_14.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_15.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_15.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..e872918a1fc17d98f0bdd4e4175dd69d5151ce5c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_15.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_16.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_16.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..ea8d2e321d7951a9c6f2b56b4fcca5a5991b1f19 Binary files /dev/null and 
b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_16.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_17.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_17.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..9becb7a4081d03d89701d10cb37ff29ee28505c6 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_17.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_18.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_18.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..09e3756a98727a6614397c0aa0e5c9da99977f1c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_18.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_19.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_19.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..426bd28bd13c06fa5a3d92fd57b16c4ff7e0f094 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_19.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_2.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_2.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..ca3d71e8e3c5eeddad397d00f8057eb1d199adb7 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_2.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_20 .jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_20.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..3678fdb8ccf5ea8c02d4ee49e2264c529a2f4d54 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_20.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_21.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_21.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..da94cb8bc8c690c64088158b3eb73ba369ab5325 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_21.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_22.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_22.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..29af75a1b43a6ac0ec07786ad1af2319913b6554 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_22.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_23.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_23.jsonl.gz new file mode 100644 index 
0000000000000000000000000000000000000000..213c494fdc61aa0df696aa3f1447947d85e63b0d Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_23.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_24.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_24.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..b3c2d9da7bde6f347204095c7fd4f044ee46a3bf Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_24.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_25.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_25.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..0395fc25df6cc6eaf7d8716bf58dd903d4dda602 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_25.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_26.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_26.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..3de1059b75aecce2216c360a4ce324222d029329 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_26.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_27.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_27.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..276238a9635e61cbbe8f9503ac5f7f39a1f62cc9 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_27.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_28.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_28.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..30393b4dc323ab97213abbff07c7ddd5bd3d2ae6 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_28.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_29.jsonl.gz b/model/third_part y/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_29.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..851e0780d6d5bfda79310eb4131767fca50c986a Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_29.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_3.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_3.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..bf76234f14948e762a7eece5531002229a0d62fa Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_3.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_30.jsonl.gz 
b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_30.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..4c36050dbba443fe8acc1448bfee34631f44969f Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_30.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_31.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_31.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..c78adedbdedf4029d517762b29e73a2f08bcdf3b Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_31.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_4.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_4.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..86defa2c5bfdcb062abc924209e8a4649512e770 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_4.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_5.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_5.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..d032891e41b0c8bfa9ebef7c9dd33f3d566ed7bc Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_5.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_6.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_6.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..d170ab349a9a4cf1aff9864532106218854a82f9 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_6.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_7.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_7.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..856ed340f3e22e28212ed7bb0a842a801ade8f59 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_7.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_8.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_8.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..4cd094567d6795db1901e26bc2d1c18a616f4580 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_8.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_9.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AM I_proprec/train/split_9.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..62340aabb094d21d63037c085e32a3795abc30a5 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/split_9.jsonl.gz differ diff --git 
a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json new file mode 100644 index 0000000000000000000000000000000000000000..72ed25b70f63f3d8f6110bd75e3724e48b744a69 --- /dev/null +++ b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json @@ -0,0 +1,10 @@ +[ + { + "source": + { + "dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/train/" + }, + "task": "meeting", + "name": "ami" + } +] diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json new file mode 100644 index 0000000000000000000000000000000000000000..0df95b9f7caa98c59afc564014ec7ddb46242329 --- /dev/null +++ b/model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json @@ -0,0 +1,10 @@ +[ + { + "source": + { + "dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/dev/" + }, + "task": "meeting", + "name": "ami" + } +] diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_0.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_0.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..77dc09ab7e969820e94bd643ad5205128720a949 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_0.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_1.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_1.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..9f5179d7a4239eea8a016d05e1793120fcef4c8a Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_1.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_2.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_2.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..c5e8bedb7435b27265a9c94924438403d53068c0 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_2.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_3.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_3.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..d15799954c08990b8067b85272005c3e4812305f Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_3.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_4.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_4.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..c980fb9e1a638a817bd3ca72c2bceab0750db7ed Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_4.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_5.jsonl.gz 
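The train_ami.json and valid_ami.json specs above (like their test and ICSI counterparts) each point a named "meeting" task at a directory of gzipped JSONL shards, i.e. the split_*.jsonl.gz files added in this change. A sketch of walking one of those shard directories; the record fields are not shown in this diff, so nothing is assumed about them beyond one JSON object per line:

import glob
import gzip
import json
import os

def iter_examples(shard_dir):
    # e.g. shard_dir = "model/third_party/HMNet/ExampleRawData/meeting_summarization/AMI_proprec/train/"
    for path in sorted(glob.glob(os.path.join(shard_dir, "split_*.jsonl.gz"))):
        with gzip.open(path, "rt", encoding="utf-8") as f:
            for line in f:
                yield json.loads(line)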
b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_5.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..cf50f8c10e6e952e29c5059f0af7d324b678b8c7 Binary files /dev/null and b/m odel/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_5.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_6.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_6.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..b9f7f2b624475a12a1ee29187568a3ac6085483f Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_6.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_7.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_7.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..2ec2052a275b09e8a22199ef7ddd8f6d8eeb411b Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_7.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_8.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_8.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..6484bea1db119bec1236f3560ca6c09a0707893a Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_8.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_9.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_9.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..477770d8b8265b0eb24e874566769bc79ef1f0ee Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/dev/split_9.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_0.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_0.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..e90093e40dd780d125804197857a1f28f0c0ee5c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_0.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_1.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_1.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..0353a802f6512f69834c7d96a4eff9dc7d3102fe Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_1.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_2.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_2.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..26a3a30ea8b80bebd60985490d3fc57d3eb1a122 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_2.jsonl.gz differ diff --git 
a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_3.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_3.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..b1951576cd579d4f8393dc05992def4e0c0aeecf Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_3.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_4.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_4.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..58dc5186bfe06aaf7db5f6b880c86ad45380e949 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_pro prec/test/split_4.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_5.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_5.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..e6a41599b12798ebf75d837fb39ebda28e5237c4 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test/split_5.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test_icsi.json b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test_icsi.json new file mode 100644 index 0000000000000000000000000000000000000000..455fdc72358115205049cd226ade793c355df968 --- /dev/null +++ b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/test_icsi.json @@ -0,0 +1,10 @@ +[ + { + "source": + { + "dataset": "../ExampleRawData/meeting_summarization/ICSI_proprec/test/" + }, + "task": "meeting", + "name": "icsi" + } +] diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_0.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_0.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..1e8dc32e147a534a31b18b2fce9abbc9ff034b82 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_0.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_1.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_1.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..d151c8cac9e0de765c2e9bbba28974fbe70b4655 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_1.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_10.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_10.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..fc92648421c511a14ccf55652dc6b274c9f5f247 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_10.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_11.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_11.jsonl.gz new file mode 100644 index 
0000000000000000000000000000000000000000..c2128baf418dd08511aabb3aa87901acbc10a23c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_11.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_12.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_12.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..c1e463ff774b4c0e99795e7a8f3d95074b339bd0 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_12.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_13.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_13.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..7c8253329544ee2602cccd30313951da55bf5d09 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_13.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_14.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_14.jsonl.gz new file mode 100644 index 00000000000000000000000 00000000000000000..881a5005939df791508779659e814300bdbee8a3 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_14.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_15.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_15.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..c5edd3e3ab2a0892f108582edcbef68fd7d4b255 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_15.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_16.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_16.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..0ab3f38105ba124ba32dc5af5c5e97f62b7f19c4 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_16.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_17.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_17.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..96fb89fd632f87d9f5f77955caa68045b79b85bb Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_17.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_18.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_18.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..365cbef95534127fd9108e2a0bdf904d639f820c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_18.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_19.jsonl.gz 
b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_19.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..f05b177f8a80313c77a16d7625a4d4a9f351ffda Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_19.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_2.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_2.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..283b71c41da3bb24089930b3808456a4d2a8eb07 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_2.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_20.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_20.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..62c4419ad474eabb1de5b548d0c3217e13bad466 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_20.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_21.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_21.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..dae2e357aec7755b62929f497a7e230e303a2b5b Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_21.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_22.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_22.jsonl.gz new file mode 100644 index 000000000000000000000000000 0000000000000..783495cf010f0e16be1cb94d1809d103b0a715f0 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_22.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_23.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_23.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..46254513d9373e81960128a31b2728ae8f726faa Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_23.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_24.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_24.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..f45917318656173cb20e3ffb96f10d669636ded3 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_24.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_25.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_25.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..284dd54583b003f08598f59b6b754aec2dd3963c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_25.jsonl.gz differ diff --git 
a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_26.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_26.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..e42c016b9105fb88706299d0d3ad8d320e647897 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_26.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_27.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_27.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..053ac2c4f7747663aac2f62d4df3996f7df39866 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_27.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_28.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_28.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..064968464947fbbe4d8d5e8d1d6dbac25bda97ac Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_28.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_29.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_29.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..fe6fbb9df2ac49b62b2b4a4ea58d489625875378 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_29.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_3.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_3.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..2addf7c08ab6eb150af83c464c482719c55ef7cd Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_3.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_30.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_30.jsonl.gz new file mode 100644 index 0000000000000000000000000000000 000000000..03f0e12bece930d5d942553974223f23990d01db Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_30.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_31.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_31.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..9426ab39006a09392c4fc9d6989ace349e4c7cbe Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_31.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_4.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_4.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..36d23ee48110fa87b05ff3464a9fbea40d62c68d Binary files /dev/null and 
b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_4.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_5.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_5.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..c8dfd1dc2415faf7923fe63016ed95bb6ac9202c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_5.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_6.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_6.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..e30b51290a1605b87cf621c45e9bf6095c37405f Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_6.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_7.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_7.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..1d84f31d9ad6e3dacdee4f805cbfc9940413ab3c Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_7.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_8.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_8.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..954197d10de625a4a963962eb789cbbdb8017213 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_8.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_9.jsonl.gz b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_9.jsonl.gz new file mode 100644 index 0000000000000000000000000000000000000000..62160070f8a547cf0d777ce0487e6f6096095ab0 Binary files /dev/null and b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train/split_9.jsonl.gz differ diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train_icsi.json b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train_icsi.json new file mode 100644 index 0000000000000000000000000000000000000000..c771d70e4eaeb96c628ea1a7a8d744142fed8fd3 --- /dev/null +++ b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/train_icsi.json @@ -0,0 +1,10 @@ +[ + { + "source": + { + "dataset": "../ExampleRawData/meeting_summarization/ICSI_proprec/train/" + }, + "task": "meeting", + "name": "icsi" + } +] diff --git a/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/valid_icsi.json b/model/ third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/valid_icsi.json new file mode 100644 index 0000000000000000000000000000000000000000..5ab37d2d95dcd06d832919535f6169c4a493ee8e --- /dev/null +++ b/model/third_party/HMNet/ExampleRawData/meeting_summarization/ICSI_proprec/valid_icsi.json @@ -0,0 +1,10 @@ +[ + { + "source": + { + "dataset": "../ExampleRawData/meeting_summarization/ICSI_proprec/dev/" + }, + "task": "meeting", + "name": "icsi" + } +] diff --git 
a/model/third_party/HMNet/ExampleRawData/meeting_summarization/role_dict_ext.json b/model/third_party/HMNet/ExampleRawData/meeting_summarization/role_dict_ext.json new file mode 100644 index 0000000000000000000000000000000000000000..b0706aac0840bbaf69cdfc0c2da64eab7fa68164 --- /dev/null +++ b/model/third_party/HMNet/ExampleRawData/meeting_summarization/role_dict_ext.json @@ -0,0 +1,38 @@ +{ + "": 0, + "PM": 1, + "ID": 2, + "UI": 3, + "ME": 4, + "Grad": 5, + "Professor": 6, + "Postdoc": 7, + "PhD": 8, + "cnn": 9, + "xsum": 10, + "nyt": 11, + "cnn-0": 12, + "cnn-1": 13, + "cnn-2": 14, + "cnn-3": 15, + "cnn-4": 16, + "cnn-5": 17, + "cnn-6": 18, + "cnn-7": 19, + "xsum-0": 20, + "xsum-1": 21, + "xsum-2": 22, + "xsum-3": 23, + "xsum-4": 24, + "xsum-5": 25, + "xsum-6": 26, + "xsum-7": 27, + "nyt-0": 28, + "nyt-1": 29, + "nyt-2": 30, + "nyt-3": 31, + "nyt-4": 32, + "nyt-5": 33, + "nyt-6": 34, + "nyt-7": 35 +} \ No newline at end of file diff --git a/model/third_party/HMNet/Models/Criteria/MLECriterion.py b/model/third_party/HMNet/Models/Criteria/MLECriterion.py new file mode 100644 index 0000000000000000000000000000000000000000..c92da7ffb610ee94efb15620f402d7d5ffdfbfc1 --- /dev/null +++ b/model/third_party/HMNet/Models/Criteria/MLECriterion.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLECriterion(nn.Module): + """ + Class to define loss give input, model output and groundtruth + """ + + def __init__(self, opt, module): + super().__init__() + self.opt = opt + self.ignore_index = ( + self.opt["IGNORE_INDEX"] + if "IGNORE_INDEX" in self.opt + else module.tokenizer.pad_token_id + ) + + def forward(self, vocab_logprob, batch): + extended_vocab_size = vocab_logprob.shape[2] + y = batch["decoder_input_ids"] + + if "USE_BOS_TOKEN" in self.opt: + y = y[:, 1:] + + if "USE_EOS_TOKEN" in self.opt: + vocab_logprob = vocab_logprob[:, :-1, :] + + loss = F.nll_loss( + vocab_logprob.contiguous().view(-1, extended_vocab_size), + y.contiguous().view(-1), + ignore_index=self.ignore_index, + ) + + return loss diff --git a/model/third_party/HMNet/Models/Networks/Layers.py b/model/third_party/HMNet/Models/Networks/Layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3bdf090fded691eedfd86905da1570500e30adf0 --- /dev/null +++ b/model/third_party/HMNet/Models/Networks/Layers.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
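# ---------------------------------------------------------------------------
# [Editor's aside -- illustrative sketch, not part of the HMNet patch above.]
# MLECriterion.forward (added earlier in this patch) flattens the decoder's
# per-token log-probabilities and computes a padding-masked negative
# log-likelihood. The toy names below (PAD_ID, toy_logprob, toy_targets) are
# invented for illustration; only the F.nll_loss(..., ignore_index=...)
# reduction mirrors the code above.
import torch
import torch.nn.functional as F

PAD_ID = 0                                    # assumed padding token id
batch_size, tgt_len, vocab_size = 2, 4, 10

# stand-in for the model output: log-probabilities over the vocabulary
toy_logprob = F.log_softmax(torch.randn(batch_size, tgt_len, vocab_size), dim=-1)
# stand-in for gold token ids, with the last position of each row padded
toy_targets = torch.randint(1, vocab_size, (batch_size, tgt_len))
toy_targets[:, -1] = PAD_ID

# same reduction as MLECriterion: (batch*len, vocab) vs. (batch*len,);
# padded positions are excluded from the average via ignore_index
loss = F.nll_loss(
    toy_logprob.reshape(-1, vocab_size),
    toy_targets.reshape(-1),
    ignore_index=PAD_ID,
)
print(float(loss))
# ---------------------------------------------------------------------------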
+ +import os +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torch.nn.init as init +from torch.nn.parameter import Parameter +from torch.nn.utils.rnn import pad_packed_sequence as unpack +from torch.nn.utils.rnn import pack_padded_sequence as pack + + +def set_dropout_prob(p): + global dropout_p + dropout_p = p + + +def set_seq_dropout(option): # option = True or False + global do_seq_dropout + do_seq_dropout = option + + +def seq_dropout(x, p=0, training=False): + """ + x: batch * len * input_size + """ + if training == False or p == 0: + return x + dropout_mask = Variable( + 1.0 + / ( 1 - p) + * torch.bernoulli((1 - p) * (x.data.new(x.size(0), x.size(2)).zero_() + 1)), + requires_grad=False, + ) + return dropout_mask.unsqueeze(1).expand_as(x) * x + + +def dropout(x, p=0, training=False): + """ + x: (batch * len * input_size) or (any other shape) + """ + if do_seq_dropout and len(x.size()) == 3: # if x is (batch * len * input_size) + return seq_dropout(x, p=p, training=training) + else: + return F.dropout(x, p=p, training=training) diff --git a/model/third_party/HMNet/Models/Networks/MeetingNet_Transformer.py b/model/third_party/HMNet/Models/Networks/MeetingNet_Transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e3e33c18b65e84a7b360aa1c5267051a586916 --- /dev/null +++ b/model/third_party/HMNet/Models/Networks/MeetingNet_Transformer.py @@ -0,0 +1,1528 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import copy +import math +import numpy as np +import random +import time +import torch +from torch.autograd import Variable +from torch.distributions import Categorical +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from model.third_party.HMNet.Models.Networks.Layers import dropout, set_seq_dropout +from model.third_party.HMNet.Models.Networks.Transformer import ( + EncoderBlock, + LayerNorm, + Embedder, + Splitter, + Attention, + MLP, +) +from model.third_party.HMNet.ThirdParty.Huggingface.Transformers.src.transformers import ( + tokenization_transfo_xl, +) +from model.third_party.HMNet.ThirdParty.Huggingface.Transformers.src.transformers.modeling_encoder_decoder import ( + calc_banned_ngram_tokens, + calc_banned_bad_words_ids, + top_k_top_p_filtering, + BeamHypotheses, +) +import sys +import os + +# These two dicts are adapted from SpaCy 2.3.1, since HMNet's embedding for POS and ENT is fixed +POS = { + "": 0, + "$": 1, + "''": 2, + ",": 3, + "-LRB-": 4, + "-RRB-": 5, + ".": 6, + ":": 7, + "ADD": 8, + "AFX": 9, + "CC": 10, + "CD": 11, + "DT": 12, + "EX": 13, + "FW": 14, + "HYPH": 15, + "IN": 16, + "JJ": 17, + "JJR": 18, + "JJS": 19, + "LS": 20, + "MD": 21, + "NFP": 22, + "NN": 23, + "NNP": 24, + "NNPS": 25, + "NNS": 26, + "PDT": 27, + "POS": 28, + "PRP": 29, + "PRP$": 30, + "RB": 31, + "RBR": 32, + "RBS": 33, + "RP": 34, + "SYM": 35, + "TO": 36, + "UH": 37, + "VB": 38, + "VBD": 39, + "VBG": 40, + "VBN": 41, + "VBP": 42, + "VBZ": 43, + "WDT": 44, + "WP": 45, + "WP$": 46, + "WRB": 47, + "XX": 48, + "_SP": 49, + "``": 50, +} +ENT = { + "": 0, + "B-ORG": 1, + "B-DATE": 2, + "B-PERSON": 3, + "B-GPE": 4, + "B-MONEY": 5, + "B-CARDINAL": 6, + "B-NORP": 7, + "B-PERCENT": 8, + "B-WORK_OF_ART": 9, + "B-LOC": 10, + "B-TIME": 11, + "B-QUANTITY": 12, + "B-FAC": 13, + "B-EVENT": 14, + "B-ORDINAL": 15, + "B-PRODUCT": 16, + "B-LAW": 17, + "B-LANGUAGE": 18, + "I-ORG": 19, + "I-DATE": 
20, + "I-PERSON": 21, + "I-GPE": 22, + "I-MONEY": 23, + "I-CARDINAL": 24, + "I-NORP": 25, + "I-PERCENT": 26, + "I-WORK_OF_ART": 27, + "I-LOC": 28, + "I-TIME": 29, + "I-QUANTITY": 30, + "I-FAC": 31, + "I-EVENT": 32, + "I-ORDINAL": 33, + "I-PRODUCT": 34, + "I-LAW": 35, + "I-LANGUAGE": 36, + "L-ORG": 37, + "L-DATE": 38, + "L-PERSON": 39, + "L-GPE": 40, + "L-MONEY": 41, + "L-CARDINAL": 42, + "L-NORP": 43, + "L-PERCENT": 44, + "L-WORK_OF_ART": 45, + "L-LOC": 46, + "L-TIME": 47, + "L-QUANTITY": 48, + "L-FAC": 49, + "L-EVENT": 50, + "L-ORDINAL": 51, + "L-PRODUCT": 52, + "L-LAW": 53, + "L-LANGUAGE": 54, + "U-ORG": 55, + "U-DATE": 56, + "U-PERSON": 57, + "U-GPE": 58, + "U-MONEY": 59, + "U-CARDINAL": 60, + "U-NORP": 61, + "U-PERCENT": 62, + "U-WORK_OF_ART": 63, + "U-LOC": 64, + "U-TI ME": 65, + "U-QUANTITY": 66, + "U-FAC": 67, + "U-EVENT": 68, + "U-ORDINAL": 69, + "U-PRODUCT": 70, + "U-LAW": 71, + "U-LANGUAGE": 72, + "O": 73, +} + + +class MeetingNet_Transformer(nn.Module): + def __init__(self, opt): + super(MeetingNet_Transformer, self).__init__() + + self.opt = opt + self.use_cuda = self.opt["cuda"] == True + self.config = {} + + # load tokenizer + self.tokenizer_class = getattr(tokenization_transfo_xl, opt["PRE_TOKENIZER"]) + self.pretrained_tokenizer_path = os.path.join( + opt["datadir"], opt["PRE_TOKENIZER_PATH"] + ) + if not os.path.isdir(self.pretrained_tokenizer_path): + """ + This if-else statement makes sure the pre-trained tokenizer exists + If it does not exist, it assumes the input string is the HuggingFace tokenizer name, + and downloads it from their website. + """ + self.pretrained_tokenizer_path = opt["PRE_TOKENIZER_PATH"] + else: + print("Loading Tokenizer from {}...".format(self.pretrained_tokenizer_path)) + + # here is a simple workaround to make sure all special tokens are not None + self.tokenizer = self.tokenizer_class.from_pretrained( + self.pretrained_tokenizer_path + ) + special_tokens_tuple_list = [ + ("eos_token", 128), + ("unk_token", 129), + ("pad_token", 130), + ("bos_token", 131), + ] + + for special_token_name, special_token_id_offset in special_tokens_tuple_list: + if getattr(self.tokenizer, special_token_name) == None: + setattr( + self.tokenizer, + special_token_name, + self.tokenizer.convert_ids_to_tokens( + len(self.tokenizer) - special_token_id_offset + ), + ) + self.config[special_token_name] = self.tokenizer.convert_ids_to_tokens( + len(self.tokenizer) - special_token_id_offset + ) + self.config[special_token_name + "_id"] = ( + len(self.tokenizer) - special_token_id_offset + ) + + self.vocab_size = self.tokenizer.vocab_size + opt["vocab_size"] = self.vocab_size + self.role_size = int(opt["ROLE_SIZE"]) + vocab_dim = int(opt["VOCAB_DIM"]) + role_dim = int(opt["ROLE_DIM"]) + opt["transformer_embed_dim"] = vocab_dim + embed = nn.Embedding( + self.vocab_size, vocab_dim, padding_idx=self.tokenizer.pad_token_id + ) + nn.init.normal_(embed.weight, std=0.02) + embedder = Embedder(opt, embed) + role_embed = nn.Embedding(self.role_size, role_dim, padding_idx=0) + + self.encoder = Encoder( + opt, self.vocab_size, vocab_dim, role_dim, embedder, role_embed + ) + self.decoder = Decoder( + opt, + vocab_dim, + self.vocab_size, + embedder, + self.encoder.token_transformer_dim, + self.encoder.sent_transformer_dim, + ) + + if "PYLEARN_MODEL" in self.opt: + self.from_pretrained(os.path.join(opt["datadir"], opt["PYLEARN_MODEL"])) + + def save_pretrained(self, save_dir): + network_state = dict([(k, v) for k, v in self.state_dict().items()]) + params = { + "state_dict": {"network": 
network_state}, + "config": self.opt, + } + torch.save(params, os.path.join(save_dir, "model.pt")) + + def from_pretrained(self, load_dir): + checkpoint = torch.load( + os.path.join(load_dir, "model.pt"), + map_location=torch.device("cuda", self.opt["local_rank"]) + if self.use_cuda + else "cpu", + ) + state_dict = checkpoint["state_dict"] + + self.load_state_dict(state_dict["network"]) + + return self + + def get_training_parameters(self): + return [p for p in self.parameters() if p.requires_grad] + + def forward(self, batch, beam_search=False, max_sent_len=None): + if beam_search: + # return self.beam_search(batch, max_sent_len) + return self.generate(batch, max_sent_len) + + outputs = self._forward(**batch) + vocab_logprob = outputs[0] + + # assume all encoder-decoder model input has BOS and EOS + # otherwise the loss will be ill-defined + return vocab_logprob + + """ + Input: + encoders_input_ids = 1 * num_turns * x_len (word_ids) + encoders_input_roles = 1 * num_turns (role_ids) + encoders_input_pos = 1 * num_turns * x_len (pos_ids) + encoders_input_ent = 1 * num_turns * x_len (ent_ids) + decoder_input_ids = 1 * y_len (word_ids) + Output: + vocab_logprob = 1 x y_len x vocab_size + """ + + def _forward(self, **kwargs): + + encoder_input_ids = kwargs.pop("encoder_input_ids") + encoder_input_roles = kwargs.pop("encoder_input_roles") + encoder_input_pos = kwargs.pop("encoder_input_pos") + encoder_input_ent = kwargs.pop("encoder_input_ent") + decoder_input_ids = kwargs.pop("decoder_input_ids") + + token_encoder_outputs, sent_encoder_outputs = self.encoder( + encoder_input_ids, encoder_input_roles, encoder_input_pos, encoder_input_ent + ) + vocab_logprob = self.decoder( + token_encoder_outputs, sent_encoder_outputs, decoder_input_ids + ) + return vocab_logprob, (token_encoder_outputs, sent_encoder_outputs) + + def generate(self, batch, max_sent_len): + self.eval() + self.beam_width = int(self.opt["BEAM_WIDTH"]) + + input_ids = batch["encoder_input_ids"] + input_roles = batch["encoder_input_roles"] + input_pos = batch["encoder_input_pos"] + input_ent = batch["encoder_input_ent"] + + batch_size = input_ids.shape[0] + + num_return_sequences = self.opt.get("NUM_RETURN_SEQUENCES", 1) + outputs = self._generate( + input_ids=input_ids, + input_roles=input_roles, + input_pos=input_pos, + input_ent=input_ent, + min_length=self.opt.get("MIN_GEN_LENGTH", None), + max_length=max_sent_len, + num_beams=self.beam_width, + bad_words_ids=None, + bos_token_id=self.tokenizer.bos_token_id, + decoder_start_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + do_sample=self.opt.get("DO_SAMPLE", False), + top_k=self.opt.get("TOP_K", 50), + top_p=self.opt.get("TOP_P", 1), + repetition_penalty=self.opt.get("REPETITION_PENALTY", 1.0), + length_penalty=self.opt.get("LENGTH_PENALTY", 1.0), + no_repeat_ngram_size=self.opt.get("NO_REPEAT_NGRAM_SIZE", 3), + num_return_sequences=num_return_sequences, + ) + + sents = [] + outputs = outputs.view(outputs.shape[0], num_return_sequences, -1) + + for idx in range(batch_size): + # TODO: use real inference scores + candidates = [ + (self.tokenizer.convert_ids_to_tokens(outputs[idx, i, :]), 0.0) + for i in range(num_return_sequences) + ] + sents.append(candidates) + + return sents + + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + # first step + if type(past) is tuple: + encoder_outputs = past + 
else: + encoder_outputs = (past,) + + return { + "decoder_input_ids": input_ids, + "token_encoder_outputs": encoder_outputs[0], + "sent_encoder_outputs": encoder_outputs[1], + } + + def prepare_scores_for_generation(self, scores, **kwargs): + return scores + + def enforce_repetition_penalty_( + self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty + ): + """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858).""" + for i in range(batch_size * num_beams): + for previous_token in set(prev_output_tokens[i].tolist()): + # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if lprobs[i, previous_token] < 0: + lprobs[i, previous_token] *= repetition_penalty + else: + lprobs[i, previous_token] /= repetition_penalty + + @torch.no_grad() + def _generate( + self, + input_ids=None, + input_roles=None, + input_pos=None, + input_ent=None, + max_length=None, + min_length=None, + do_sample=None, + early_stopping=False, + num_beams=None, + temperature=1.0, + top_k=None, + top_p=None, + repetition_penalty=None, + bad_words_ids=None, + bos_token_id=None, + pad_token_id=None, + eos_token_id=None, + length_penalty=None, + no_repeat_ngram_size=None, + num_return_sequences=None, + attention_mask=None, + decoder_start_token_id=None, + ): + r"""Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. + + Adapted in part from `Facebook's XLM beam search code`_. + + .. _`Facebook's XLM beam search code`: + https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529 + + + Parameters: + + input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)` + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape `(1,)`. + + max_length: (`optional`) int + The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20. + + min_length: (`optional`) int + The min length of the sequence to be generated. Between 0 and infinity. Default to 0. + + do_sample: (`optional`) bool + If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. + + early_stopping: (`optional`) bool + if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. + + num_beams: (`optional`) int + Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. + + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + + top_k: (`optional`) int + The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + + top_p: (`optional`) float + The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. + + repetition_penalty: (`optional`) float + The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. + + pad_token_id: (`optional`) int + Padding token. Default to specicic model pad_token_id or None if it does not exist. + + bos_token_id: (`optional`) int + BOS token. 
Defaults to `bos_token_id` as defined in the models config. + + eos_token_id: (`optional`) int + EOS token. Defaults to `eos_token_id` as d efined in the models config. + + length_penalty: (`optional`) float + Exponential penalty to the length. Default to 1. + + no_repeat_ngram_size: (`optional`) int + If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once. + bad_words_ids: (`optional`) list of lists of int + `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. + + num_return_sequences: (`optional`) int + The number of independently computed returned sequences for each element in the batch. Default to 1. + + attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids` + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + Defaults to `None`. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + decoder_start_token_id=None: (`optional`) int + If an encoder-decoder model starts decoding with a different token than BOS. + Defaults to `None` and is changed to `BOS` later. + + Return: + + output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)` + sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id` + + Examples:: + + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + outputs = model.generate(max_length=40) # do greedy decoding + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. + input_context = 'The dog' + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + input_context = 'The dog' + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. 
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. + input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl + bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated + """ + + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + do_sample = do_sample if do_sample is not None else self.config.do_sample + early_stopping = ( + early_stopping if early_stopping is not None else self.config.early_stopping + ) + num_beams = num_beams if num_beams is not None else self.config.num_beams + temperature = ( + temperature if temperature is not None else self.config.temperature + ) + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + repetition_penalty = ( + repetition_penalty + if repetition_penalty is not None + else self.config.repetition_penalty + ) + bos_token_id = ( + bos_token_id if bos_token_id is not None else self.config.bos_token_id + ) + pad_token_id = ( + pad_token_id if pad_token_id is not None else self.config.pad_token_id + ) + eos_token_id = ( + eos_token_id if eos_token_id is not None else self.config.eos_token_id + ) + length_penalty = ( + length_penalty if length_penalty is not None else self.config.length_penalty + ) + no_repeat_ngram_size = ( + no_repeat_ngram_size + if no_repeat_ngram_size is not None + else self.config.no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids + num_return_sequences = ( + num_return_sequences + if num_return_sequences is not None + else self.config.num_return_sequences + ) + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.config.decoder_start_token_id + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] # overriden by the input batch_size + else: + batch_size = 1 + + assert ( + isinstance(max_length, int) and max_length > 0 + ), "`max_length` should be a strictly positive integer." + assert ( + isinstance(min_length, int) and min_length >= 0 + ), "`min_length` should be a positive integer." + assert isinstance(do_sample, bool), "`do_sample` should be a boolean." + assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." + assert ( + isinstance(num_beams, int) and num_beams > 0 + ), "`num_beams` should be a strictly positive integer." + assert temperature > 0, "`temperature` should be strictly positive." + assert ( + isinstance(top_k, int) and top_k >= 0 + ), "`top_k` should be a positive integer." + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." 
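# ---------------------------------------------------------------------------
# [Editor's aside -- illustrative sketch, not part of the HMNet patch above.]
# enforce_repetition_penalty_ (defined earlier in MeetingNet_Transformer.py)
# applies the CTRL-style repetition penalty used during generation: tokens
# that were already produced are made less likely. Because log-probabilities
# are negative, "penalising" means multiplying by the penalty factor; a
# positive score would instead be divided. The tensors below are toy values
# for illustration only.
import torch

penalty = 1.2
lprobs = torch.log(torch.tensor([[0.5, 0.3, 0.2]]))   # 1 x vocab, all negative
prev_output_tokens = torch.tensor([[1, 1]])           # token 1 generated before

for i in range(lprobs.size(0)):
    for previous_token in set(prev_output_tokens[i].tolist()):
        if lprobs[i, previous_token] < 0:
            lprobs[i, previous_token] *= penalty      # more negative => less likely
        else:
            lprobs[i, previous_token] /= penalty
print(lprobs)                                         # score of token 1 is pushed down
# ---------------------------------------------------------------------------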
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." + assert input_ids is not None or ( + isinstance(bos_token_id, int) and bos_token_id >= 0 + ), "If input_ids is not defined, `bos_token_id` should be a positive integer." + assert pad_token_id is None or ( + isinstance(pad_token_id, int) and (pad_token_id >= 0) + ), "`pad_token_id` should be a positive integer." + assert (eos_token_id is None) or ( + isinstance(eos_token_id, int) and (eos_token_id >= 0) + ), "`eos_token_id` should be a positive integer." + assert length_penalty > 0, "`length_penalty` should be strictly positive." + assert ( + isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 + ), "`no_repeat_ngram_size` should be a positive integer." + assert ( + isinstance(num_return_sequences, int) and num_return_sequences > 0 + ), "`num_return_sequences` should be a strictly positive integer." + assert ( + bad_words_ids is None + or isinstance(bad_words_ids, list) + and isinstance(bad_words_ids[0], list) + ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" + + if input_ids is None: + assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( + "you should either supply a context to complete as `input_ids` input " + "or a `bos_token_id` (integer >= 0) as a first token to start the generation." + ) + input_ids = torch.full( + (batch_size, 1), + bos_token_id, + dtype=torch.long, + device=next(self.parameters()).device, + ) + else: + assert ( + input_ids.dim() == 3 + ), "Input prompt should be of shape (batch_size, sequence length)." + + # not allow to duplicate outputs when greedy decoding + if do_sample is False: + if num_beams == 1: + # no_beam_search greedy generation conditions + assert ( + num_return_sequences == 1 + ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" + + else: + # beam_search greedy generation conditions + assert ( + num_beams >= num_return_sequences + ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" + + # create attention mask if necessary + # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 + if ( + (attention_mask is None) + and (pad_token_id is not None) + and (pad_token_id in input_ids) + ): + attention_mask = input_ids.ne(pad_token_id).long() + elif attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + # set pad_token_id to eos_token_id if not set. 
Important that this is done after + # attention_mask is created + if pad_token_id is None and eos_token_id is not None: + logger.warning( + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format( + eos_token_id + ) + ) + pad_token_id = eos_token_id + + # current position and vocab size + vocab_size = self.vocab_size + + # set effective batch size and effective batch multiplier according to do_sample + if do_sample: + effective_batch_size = batch_size * num_return_sequences + effective_batch_mult = num_return_sequences + else: + effective_batch_size = batch_size + effective_batch_mult = 1 + + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + + assert ( + decoder_start_token_id is not None + ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" + + encoder_outputs = self.encoder(input_ids, input_roles, input_pos, input_ent) + + # # Expand input ids if num_beams > 1 or num_return_sequences > 1 + # if num_return_sequences > 1 or num_beams > 1: + # input_sent_len = input_ids.shape[2] + # input_word_len = input_ids.shape[3] + # input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_sent_len, input_word_len) + # attention_mask = attention_mask.unsqueeze(1).expand( + # batch_size, effective_batch_mult * num_beams, input_sent_len, input_word_len + # ) + + # input_ids = input_ids.contiguous().view( + # effective_batch_size * num_beams, input_sent_len, input_word_len + # ) # shape: (batch_size * num_return_sequences * num_beams, input_sent_len, input_word_len) + # attention_mask = attention_mask.contiguous().view( + # effective_batch_size * num_beams, input_sent_len, input_word_len + # ) # shape: (batch_size * num_return_sequences * num_beams, input_sent_len, input_word_len) + + # create empty decoder_input_ids + input_ids = torch.full( + (effective_batch_size * num_beams, 1), + decoder_start_token_id, + dtype=torch.long, + device=next(self.parameters()).device, + ) + cur_len = 1 + + assert ( + batch_size == encoder_outputs[0].shape[0] + ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " + + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) + expanded_batch_idxs = ( + torch.arange(batch_size) + .view(-1, 1) + .repeat(1, num_beams * effective_batch_mult) + .view(-1) + .to(input_ids.device) + ) + # expand encoder_outputs + encoder_outputs = ( + encoder_outputs[0].index_select(0, expanded_batch_idxs), + encoder_outputs[1].index_select(0, expanded_batch_idxs), + ) + + if num_beams > 1: + output = self._generate_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + decoder_start_token_id=decoder_start_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + ) + else: + output = self._generate_no_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + 
do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + decoder_start_token_id=decoder_start_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + ) + + return output + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + min_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + bos_token_id, + pad_token_id, + eos_token_id, + decoder_start_token_id, + batch_size, + encoder_outputs, + attention_mask, + ): + """Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + # length of generated sentences / unfinished sentences + unfinished_sents = input_ids.new(batch_size).fill_(1) + sent_lengths = input_ids.new(batch_size).fill_(max_length) + + past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask + ) + + outputs = self.decoder(**model_inputs) + next_token_logits = outputs[:, -1, :] + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + self.enforce_repetition_penalty_( + next_token_logits, batch_size, 1, input_ids, repetition_penalty + ) + + if no_repeat_ngram_size > 0: + # calculate a list of banned tokens to prevent repetitively generating the same ngrams + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 + banned_tokens = calc_banned_ngram_tokens( + input_ids, batch_size, no_repeat_ngram_size, cur_len + ) + for batch_idx in range(batch_size): + next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float( + "inf" + ) + + if bad_words_ids is not None: + # calculate a list of banned tokens according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + + for batch_idx in range(batch_size): + next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float( + "inf" + ) + + # set eos token prob to zero if min_length is not reached + if eos_token_id is not None and cur_len < min_length: + next_token_logits[:, eos_token_id] = -float("inf") + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + # Top-p/top-k filtering + next_token_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + # Sample + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1) + + # update generations and finished sentences + if eos_token_id is not None: + # pad finished sentences if eos_token_id exist + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * ( + 1 - unfinished_sents + ) + else: + tokens_to_add = next_token + + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + + if eos_token_id is not None: + eos_in_sents = tokens_to_add == eos_token_id + # if sentence is unfinished and the token to add is eos, 
sent_lengths is filled with current length + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul( + eos_in_sents.long() + ).bool() + sent_lengths.masked_fill_( + is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1 + ) + # unfinished_sents is set to zero if eos in sentence + unfinished_sents.mul_((~eos_in_sents).long()) + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sents.max() == 0: + break + + cur_len = cur_len + 1 + + # if there are different sentences lengths in the batch, some batches have to be padded + if sent_lengths.min().item() != sent_lengths.max().item(): + assert ( + pad_token_id is not None + ), "`Pad_token_id` has to be defined if batches have different lengths" + # finished sents are filled with pad_token + decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_( + pad_token_id + ) + else: + decoded = input_ids + + for hypo_idx, hypo in enumerate(input_ids): + decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]] + + return decoded + + def _generate_beam_search( + self, + input_ids, + cur_len, + max_length, + min_length, + do_sample, + early_stopping, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + bos_token_id, + pad_token_id, + eos_token_id, + decoder_start_token_id, + batch_size, + num_return_sequences, + length_penalty, + num_beams, + vocab_size, + encoder_outputs, + attention_mask, + ): + """Generate sequences for each example with beam search.""" + + # generated hypotheses + generated_hyps = [ + BeamHypotheses( + num_beams, max_length, length_penalty, early_stopping=early_stopping + ) + for _ in range(batch_size) + ] + + # scores for each sentence in the beam + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device + ) + + # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times + if do_sample is False: + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) + + # cache compute states + past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models + + # done sentences + done = [False for _ in range(batch_size)] + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask + ) + outputs = self.decoder( + **model_inputs + ) # (batch_size * num_beams, cur_len, vocab_size) + next_token_logits = outputs[ + :, -1, : + ] # (batch_size * num_beams, vocab_size) + + # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + self.enforce_repetition_penalty_( + next_token_logits, + batch_size, + num_beams, + input_ids, + repetition_penalty, + ) + + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + + scores = F.log_softmax( + next_token_logits, dim=-1 + ) # (batch_siz e * num_beams, vocab_size) + if do_sample is False: + # TODO (PVP) still a bit hacky here - there might be a better solution + scores = self.prepare_scores_for_generation( + scores, cur_len=cur_len, max_length=max_length + ) + + # set eos token prob to zero if min_length is not reached + if eos_token_id is not None and cur_len < min_length: + scores[:, eos_token_id] = -float("inf") + + if no_repeat_ngram_size > 0: + # calculate a list of banned tokens to prevent repetitively generating the same ngrams + num_batch_hypotheses 
= batch_size * num_beams + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 + banned_batch_tokens = calc_banned_ngram_tokens( + input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len + ) + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + if bad_words_ids is not None: + # calculate a list of banned tokens according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + + for i, banned_tokens in enumerate(banned_tokens): + scores[i, banned_tokens] = -float("inf") + + assert scores.shape == ( + batch_size * num_beams, + vocab_size, + ), "Shapes of scores: {} != {}".format( + scores.shape, (batch_size * num_beams, vocab_size) + ) + + if do_sample: + _scores = scores + beam_scores[:, None].expand_as( + scores + ) # (batch_size * num_beams, vocab_size) + # Top-p/top-k filtering + _scores = top_k_top_p_filtering( + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + ) # (batch_size * num_beams, vocab_size) + # re-organize to group the beam together to sample from all beam_idxs + _scores = _scores.contiguous().view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) + probs = F.softmax(_scores, dim=-1) + next_tokens = torch.multinomial( + probs, num_samples=2 * num_beams + ) # (batch_size, num_beams * 2) + # Compute next scores + next_scores = torch.gather( + _scores, -1, next_tokens + ) # (batch_size, num_beams * 2) + # sort the sampled vector to make sure that the first num_beams samples are the best + next_scores, next_scores_indices = torch.sort( + next_scores, descending=True, dim=1 + ) + next_tokens = torch.gather( + next_tokens, -1, next_scores_indices + ) # (batch_size, num_beams * 2) + + else: + next_scores = scores + beam_scores[:, None].expand_as( + scores + ) # (batch_size * num_beams, vocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + next_scores = next_scores.view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + + next_scores, next_tokens = torch.topk( + next_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + assert ( + next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) + ) + + # next batch beam content + next_batch_beam = [] + + # for each sentence + for batch_idx in range(batch_size): + + # if we are done with this sentence + if done[batch_idx]: + assert ( + len(generated_hyps[batch_idx]) >= num_beams + ), "Batch can only be done if at least {} beams have been generated".format( + num_beams + ) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + next_batch_beam.extend( + [(0, pad_token_id, 0)] * num_beams + ) # pad the batch + continue + + # next sentence beam content + next_sent_beam = [] + + # next tokens for this sentence + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx]) + ): + # get beam and token IDs + beam_id = beam_token_id // vocab_size + token_id = beam_token_id % vocab_size + + effective_beam_id = batch_idx * num_beams + beam_id + # add to generated hypotheses if end of sentence or last iteration + if (eos_token_id is not None) and (token_id.item() == 
eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = ( + beam_token_rank >= num_beams + ) + if is_beam_token_worse_than_top_num_beams: + continue + generated_hyps[batch_idx].add( + input_ids[effective_beam_id].clone(), + beam_token_score.item(), + ) + else: + # add next predicted token if it is not eos_token + next_sent_beam.append( + (beam_token_score, token_id, effective_beam_id) + ) + + # the beam for next step is full + if len(next_sent_beam) == num_beams: + break + + # Check if were done so that we can save a pad step if all(done) + done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( + next_scores[batch_idx].max().item(), cur_len=cur_len + ) + + # update next beam content + assert len(next_sent_beam) == num_beams, "Beam should always be full" + next_batch_beam.extend(next_sent_beam) + assert len(next_batch_beam) == num_beams * (batch_idx + 1) + + # stop when we are done with each sentence + if all(done): + break + + # sanity check / prepare next batch + assert len(next_batch_beam) == batch_size * num_beams + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) + beam_idx = input_ids.new([x[2] for x in next_batch_beam]) + + # re-order batch + input_ids = input_ids[beam_idx, :] + input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) + # re-order internal states + if past is not None: + past = self._reorder_cache(past, beam_idx) + + # update current length + cur_len = cur_len + 1 + + # finalize all open beam hypotheses and end to generated hypotheses + for batch_idx in range(batch_size): + if done[batch_idx]: + continue + + # test that beam scores match previously calculated scores if not eos and batch_idx not done + if eos_token_id is not None and all( + (token_id % vocab_size).item() is not eos_token_id + for token_id in next_tokens[batch_idx] + ): + assert torch.all( + next_scores[batch_idx, :num_beams] + == beam_scores.view(batch_size, num_beams)[batch_idx] + ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( + next_scores[:, :num_beams][batch_idx], + beam_scores.view(batch_size, num_beams)[batch_idx], + ) + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(num_beams): + effective_beam_id = batch_idx * num_beams + beam_id + final_score = beam_scores[effective_beam_id].item() + final_tokens = input_ids[effective_beam_id] + generated_hyps[batch_idx].add(final_tokens, final_score) + + # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch + output_batch_size = ( + batch_size if do_sample else batch_size * num_return_sequences + ) + output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences + + # select the best hypotheses + sent_lengths = input_ids.new(output_batch_size) + best = [] + + # retrieve best hypotheses + for i, hypotheses in enumerate(generated_hyps): + sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) + for j in range(output_num_return_sequences_per_batch): + effective_batch_idx = output_num_return_sequences_per_batch * i + j + best_hyp = sorted_hyps.pop()[1] + sent_lengths[effective_batch_idx] = len(best_hyp) + best.append(best_hyp) + + # shorter batches are filled with pad_token + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, 
"`Pad_token_id` has to be defined" + sent_max_len = min(sent_lengths.max().item() + 1, max_length) + decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) + + # fill with hypothesis and eos_token_id if necessary + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < max_length: + decoded[i, sent_lengths[i]] = eos_token_id + else: + # none of the hypotheses have an eos_token + assert (len(hypo) == max_length for hypo in best) + decoded = ( + torch.stack(best).type(torch.long).to(next(self.parameters()).device) + ) + + return decoded + + # force one of token_ids to be generated by setting prob of all other tokens to 0. + def _force_token_ids_generation(self, scores, token_ids): + if isinstance(token_ids, int): + token_ids = [token_ids] + all_but_token_ids_mask = torch.tensor( + [x for x in range(self.vocab_size) if x not in token_ids], + dtype=torch.long, + device=next(self.parameters()).device, + ) + assert ( + len(scores.shape) == 2 + ), "scores should be of rank 2 with shape: [batch_size, vocab_size]" + scores[:, all_but_token_ids_mask] = -float("inf") + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = [] + for layer_past in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` and `mems` is at 2nd position + reordered_layer_past = [ + layer_past[i, :].unsqueeze(0).clone().detach() for i in beam_idx + ] + reordered_layer_past = torch.cat(reordered_layer_past, dim =0) + # check that shape matches + assert reordered_layer_past.shape == layer_past.shape + reordered_past.append(reordered_layer_past) + past = tuple(reordered_past) + return past + + +""" + Transformer encoder +""" + + +class MeetingTransformerEncoder(nn.Module): + """ + Input: + transformer_embed_dim: transformer dimension + """ + + def __init__(self, opt, transformer_embed_dim): + super(MeetingTransformerEncoder, self).__init__() + vocab = int(opt["vocab_size"]) + n_layer = int(opt["TRANSFORMER_LAYER"]) + opt["transformer_embed_dim"] = transformer_embed_dim + block = EncoderBlock(opt) + self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)]) + + """ + Input: + x: batch x len x n_state + Output: + h: batch x len x n_state + """ + + def forward(self, x): + h = x + for block in self.blocks: + h = block(h, None) + return h + + +""" + One encoder block of transformer +""" + + +class MeetingDecoderBlock(nn.Module): + def __init__(self, opt, n_state): + super(MeetingDecoderBlock, self).__init__() + self.opt = opt + self.decoder_splitter = Splitter(n_state) + self.attn = Attention(n_state, opt) + self.token_attn = Attention(n_state, opt) + self.sent_attn = Attention(n_state, opt) + self.ln_1 = LayerNorm(n_state) + self.ln_2 = LayerNorm(n_state) + opt["transformer_embed_dim"] = n_state + self.mlp = MLP(4 * n_state, opt) + self.ln_3 = LayerNorm(n_state) + self.ln_4 = LayerNorm(n_state) + + """ + Input: + y: batch x len x n_state (decoder part) + token_enc_key: batch x encoder_len x n_state + token_enc_value: batch x encoder_len x n_state + sent_enc_key: batch x encoder_len x n_state + sent_enc_value: batch x encoder_len x n_state + Output: + h: batch x len x n_state + """ + + def forward(self, y, token_enc_key, token_enc_value, sent_enc_key, sent_enc_value): + query, key, value = self.decoder_splitter(y) + # batch x len x n_state + + # self-attention + a = self.attn(query, key, value, None, one_dir_visible=True) + # batch x len x n_state + + n = self.ln_1(y + a) # residual + + if "NO_HIERARCHY" in 
self.opt: + q = y + r = n + else: + # src-tgt attention on sentences + q = self.sent_attn(n, sent_enc_key, sent_enc_value, None) + r = self.ln_3(n + q) # residual + # batch x len x n_state + + # src-tgt attention on tokens + o = self.token_attn(r, token_enc_key, token_enc_value, None) + p = self.ln_2(r + o) # residual + # batch x len x n_state + + m = self.mlp(p) + h = self.ln_4(p + m) + return h + + +""" + Transformer decoder +""" + + +class MeetingTransformerDecoder(nn.Module): + """ + Input: + embed_size: decoder transformer dimension + token_dim: dimension of transformer from token encoder side + sent_dim: dimension of transformer from sent encoder side + """ + + def __init__(self, opt, embedder, embed_size, token_dim, sent_dim): + super(MeetingTransformerDecoder, self).__init__() + self.fp16 = "FP16" in opt + vocab_size = int(opt["vocab_size"]) + n_layer = int(opt["TRANSFORMER_LAYER"]) + self.encoder_splitter = Splitter(embed_size) + block = MeetingDecoderBlock(opt, embed_size) + self.token_linear = nn.Linear(token_dim, embed_size) + self.sent_linear = nn.Linear(sent_dim, embed_size) + self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)]) + self.linear = nn.Linear(embed_size, vocab_size, bias=False) + self.linear.weight = embedder.embed.weight # share weight + + """ + Input: + token_encoder_outputs: 1 x (encoder_len - sent_num) x token_transformer_dim + sent_encoder_outputs: 1 x sent_nu m x sent_transformer_dim + y: batch x len x n_state + Output: + prob: batch x len x vocab_size (probabilities after softmax) + """ + + def forward(self, token_encoder_inputs, sent_encoder_inputs, decoder_input_ids): + _, token_enc_key, token_enc_value = self.encoder_splitter( + self.token_linear(token_encoder_inputs) + ) + # token_enc_key: batch x encoder_len x n_state + # token_enc_value: batch x encoder_len x n_state + + _, sent_enc_key, sent_enc_value = self.encoder_splitter( + self.sent_linear(sent_encoder_inputs) + ) + # sent_enc_key: batch x encoder_len x n_state + # sent_enc_value: batch x encoder_len x n_state + + h = decoder_input_ids + for block in self.blocks: + h = block(h, token_enc_key, token_enc_value, sent_enc_key, sent_enc_value) + prob = F.softmax(self.linear(h), dim=-1) + return prob + + +class Encoder(nn.Module): + """ + vocab_size: size of input vocabulary + embed_size: word embedding dimension of dictionary + role_dim: role embedding dimension + embed: the nn.Embedding for vocab + role_embed: the nn.Embedding for role + """ + + def __init__(self, opt, vocab_size, embed_size, role_dim, embedder, role_embed): + super(Encoder, self).__init__() + self.opt = opt + self.vocab_size = vocab_size + + set_seq_dropout("VARIATIONAL_DROPOUT" in self.opt) + + self.embed_size = embed_size + self.embedder = embedder + self.role_embed = role_embed + + self.token_transformer_dim = embed_size + if "USE_POSENT" in opt: + print("Use POS and ENT") + pos_dim = opt["POS_DIM"] + ent_dim = opt["ENT_DIM"] + self.pos_embed = nn.Embedding(len(POS), pos_dim) + self.ent_embed = nn.Embedding(len(ENT), ent_dim) + self.token_transformer_dim += pos_dim + ent_dim + + self.sent_transformer_dim = self.token_transformer_dim + if "USE_ROLE" in opt: + print("USE_ROLE") + role_dim = opt["ROLE_DIM"] + self.sent_transformer_dim += role_dim + + self.token_encoder = MeetingTransformerEncoder(opt, self.token_transformer_dim) + self.sent_encoder = MeetingTransformerEncoder(opt, self.sent_transformer_dim) + + """ + x = bz * sent_num * x_len (word_ids) + x_role = bz * sent_num (role_ids) + x_pos = bz * 
sent_num * x_len (pos_ids) + x_ent = bz * sent_num * x_len (ent_ids) + outputs: + token_encoder_outputs: bz x x_len_total x token_transformer_dim + sent_encoder_outputs: bz x sent_num x sent_transformer_dim + """ + + def forward(self, x, x_role, x_pos, x_ent): + batch_size = x.size(0) + sent_num = x.size(1) + x_len = x.size(2) + + # x contains word id >= vocab_size + vocab_x = x.clone() + vocab_x[vocab_x >= self.vocab_size] = 1 # UNK + embedded = self.embedder(vocab_x.view(batch_size, -1)) + # embedded = 1 x sent_num * x_len x embed_size + embedded = embedded.view(batch_size, sent_num, x_len, -1) + # embedded = 1 x sent_num x x_len x embed_size + + if "USE_ROLE" in self.opt: + role_embed = self.role_embed(x_role) # 1 x sent_num x role_dim + + if "USE_POSENT" in self.opt: + embedded = torch.cat( + [embedded, self.pos_embed(x_pos), self.ent_embed(x_ent)], dim=3 + ) + # 1 x sent_num x x_len x (embed_size + pos_dim + ent_dim ) + + feat_dim = embedded.size(3) + + token_transformer_output = self.token_encoder( + embedded.view(-1, x_len, feat_dim) + ) + token_transformer_dim = token_transformer_output.size(2) + token_transformer_output = token_transformer_output.view( + batch_size, sent_num, x_len, token_transformer_dim + ) + # 1 x sent_num x x_len x token_transformer_dim + + sent_encoder_inputs = token_transformer_output[ + :, :, 0, : + ] # 1 x sent_num x token_transformer_dim + if "USE_ROLE" in self.opt: + sent_encoder_inputs = torch.cat([sent_encoder_inputs, role_embed], dim=2) + sent_encoder_outputs = self.sent_encoder( + sent_encoder_inputs + ) # 1 x sent_num x sent_transformer_dim + + token_transformer_output = token_transformer_output.view( + batch_size, -1, token_transformer_dim + ) + + return token_transformer_output, sent_encoder_outputs + + +class Decoder(nn.Module): + def __init__( + self, + opt, + embed_size, + vocab_size, + embedder, + token_transformer_dim, + sent_transformer_dim, + ): + super(Decoder, self).__init__() + self.opt = opt + self.embed_size = embed_size + self.vocab_size = vocab_size + self.embedder = embedder + self.sent_decoder = MeetingTransformerDecoder( + opt, embedder, embed_size, token_transformer_dim, sent_transformer_dim + ) + + def forward(self, token_encoder_outputs, sent_encoder_outputs, decoder_input_ids): + vocab_y = decoder_input_ids.clone() + vocab_y[vocab_y >= self.vocab_size] = 1 # UNK + embedded = self.embedder(vocab_y) + + vocab_prob = self.sent_decoder( + token_encoder_outputs, sent_encoder_outputs, embedded + ) + # vocab_prob: batch x y_len x vocab_size + + vocab_logprob = torch.log(vocab_prob + 1e-15) + return vocab_logprob diff --git a/model/third_party/HMNet/Models/Networks/Transformer.py b/model/third_party/HMNet/Models/Networks/Transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ce4582b9ca2d9ac5b6ab3720ab9e6e1581c719 --- /dev/null +++ b/model/third_party/HMNet/Models/Networks/Transformer.py @@ -0,0 +1,845 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import copy +import json +import math +import re +import collections +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.nn.parameter import Parameter + + +def gelu(x): + return ( + 0.5 + * x + * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def swish(x): + return x * torch.sigmoid(x) + + +class LayerNorm(nn.Module): + "Construct a layernorm module in the OpenAI style (epsilon inside the square root)." 
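    # i.e. normalize each position to zero mean / unit variance with the epsilon
    # added inside the square root, then rescale by g and shift by b (see forward()).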
+ + def __init__(self, n_state, e=1e-5): + super(LayerNorm, self).__init__() + self.g = nn.Parameter(torch.ones(n_state)) + self.b = nn.Parameter(torch.zeros(n_state)) + self.e = e + + """ + Input: + x: n_state-dim + Output: + o: n_state-dim + """ + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.e) + return self.g * x + self.b + + +""" + Convolution + nx is the last input dim + nf is the last output dim +""" + + +class Conv1D(nn.Module): + def __init__(self, nf, nx): + super(Conv1D, self).__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.w = Parameter(w) + self.b = Parameter(torch.zeros(nf)) + + """ + Input: + x: batch x len x nx + Output: + x: batch x len x nf + """ + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w) + x = x.view(*size_out) + return x + + +class PositionalEmbedding(nn.Module): + def __init__(self, opt, demb): + super(PositionalEmbedding, self).__init__() + self.demb = demb + inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) + self.pos_discount = float(opt["TRANSFORMER_POS_DISCOUNT"]) + self.register_buffer("inv_freq", inv_freq) + + """ + Input: + pos_seq: len + Output: + pos_emb: le n x demb + """ + + def forward(self, pos_seq): + sinusoid_inp = torch.ger(pos_seq, self.inv_freq) + pos_emb = ( + torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) + / self.pos_discount + ) + return pos_emb + + +""" + Splitter +""" + + +class Splitter(nn.Module): + def __init__(self, nx): + super(Splitter, self).__init__() + self.nx = nx + self.augmenter = Conv1D(nx * 3, nx) + + """ + Input: + x: batch x len x nx + Output: + query,key,value: batch x len x nx + """ + + def forward(self, x): + x = self.augmenter(x) + # x: batch x len x (3 x nx) + + query, key, value = x.split(self.nx, dim=2) + # query,key,value: batch x len x nx + + return query, key, value + + +""" + Multi-head Attention +""" + + +class Attention(nn.Module): + """ + nx: input dimension + """ + + def __init__(self, nx, opt): + super(Attention, self).__init__() + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implem] + n_head = int(opt["TRANSFORMER_HEAD"]) + resid_pdrop = opt["TRANSFORMER_RESIDUAL_DROPOUT"] + attn_pdrop = opt["TRANSFORMER_ATTENTION_DROPOUT"] + use_cuda = opt["cuda"] + + assert n_state % n_head == 0 + # if mask is needed, uncomment this + self.maxlen = 2048 # beyond this scale + self.mask = ( + Variable( + torch.tril(torch.ones(self.maxlen, self.maxlen)).view( + 1, 1, self.maxlen, self.maxlen + ), + requires_grad=False, + ).cuda() + if use_cuda + else Variable( + torch.tril(torch.ones(self.maxlen, self.maxlen)).view( + 1, 1, self.maxlen, self.maxlen + ), + requires_grad=False, + ) + ) + self.n_head = n_head + self.c_proj = Conv1D(n_state, nx) + self.attn_dropout = nn.Dropout(attn_pdrop) + self.resid_dropout = nn.Dropout(resid_pdrop) + self.use_cuda = use_cuda + + """ + Input: + q: batch x n_head x len x dim + k: batch x n_head x dim x kv_len + v: batch x n_head x kv_len x dim + x_mask: batch x kv_len # key and value's mask (if not None, used for encoder's self-attention and decoder's src-tgt attention) + one_dir_visible: only sees previous history (used for decoder's self-attention) + return_attn_weight: if true, also return the attention weights + Output: + a: batch x n_head x len x n_state x dim + attn_weight (if 
return_attn_weight): attn_weight: batch x n_head x len x kv_len + """ + + def _attn(self, q, k, v, x_mask, one_dir_visible, return_attn_weight): + w = torch.matmul(q, k) + # batch x n_head x len x kv_len + w = w / math.sqrt(v.size(-1)) + + mask = None + if one_dir_visible: # mask "seeing the future" + if w.size(-2) <= self.maxlen and w.size(-1) <= self.maxlen: + mask = ( + self.mask[:, :, : w.size(-2), : w.size(-1)].cuda() + if self.use_cuda + else self.mask[:, :, : w.size(-2), : w.size(-1)] + ) + else: + mask = ( + Variable( + torch.tril(torch.ones(w.size(-2), w.size(-1))).view( + 1, 1, w.size(-2), w.size(-1) + ), + requires_grad=False, + ).cuda() + if self.use_cuda + else Variable( + torch.tril(torch.ones(w.size(-2), w.size(-1))).view( + 1, 1, w.size(-2), w.size(-1) + ), + requires_grad=False, + ) + ) + + if x_mask is not None: + mask = x_ mask.unsqueeze(1).unsqueeze(1).expand_as(w).float() + # batch x n_head x len x kv_len + + if mask is not None: + w = w * mask + -1e9 * (1 - mask) + + w_prob = nn.Softmax(dim=-1)(w) + w_prob = self.attn_dropout(w_prob) + if return_attn_weight: + return torch.matmul(w_prob, v), w + else: + return torch.matmul(w_prob, v) + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states + + """ + Input: + x: batch x len x dim + Output: + not k: batch x n_head x (dim/n_head) x len + k: batch x n_head x len x (dim/n_head) + """ + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states + if k: + return x.permute(0, 2, 3, 1) + else: + return x.permute(0, 2, 1, 3) + + """ + Input: + query: batch x len x n_state + key, value: batch x kv_len x n_state + x_mask: batch x kv_len # key and value's mask (if not None, used for encoder's self-attention and decoder's src-tgt attention) + one_dir_visible: only sees previous history (used for decoder's self-attention) + return_attn_weight: if true, also return the attention weights + Output: + a: batch x len x n_state + attn_weight (if return_attn_weight): batch x len x kv_len + """ + + def forward( + self, query, key, value, x_mask, one_dir_visible=False, return_attn_weight=False + ): + query = self.split_heads(query) + # batch x n_head x len x (n_state/n_head) + + key = self.split_heads(key, k=True) + # batch x n_head x (n_state/n_head) x kv_len + + value = self.split_heads(value) + # batch x n_head x kv_len x (n_state/n_head) + + out = self._attn(query, key, value, x_mask, one_dir_visible, return_attn_weight) + + if return_attn_weight: + a, attn_weight = out + # a: batch x n_head x len x (n_state/n_head) + # attn_weight: batch x n_head x len x kv_len + attn_weight = attn_weight.permute(0, 2, 3, 1).contiguous() + # batch x len x kv_len x n_head + attn_weight = torch.sum(attn_weight, dim=3) + # batch x len x kv_len + else: + a = out + # batch x n_head x len x (n_state/n_head) + + a = self.merge_heads(a) + # batch x len x n_state + + a = self.c_proj(a) + # batch x len x n_state + + a = self.resid_dropout(a) + # batch x len x n_state + + if return_attn_weight: + return a, attn_weight + else: + return a + + +""" + Two-layer network +""" + + +class MLP(nn.Module): + """ + Input: + n_state: intermediate dim + """ + + def __init__(self, n_state, opt): # in MLP: n_state=3072 (4 * n_embd) + super(MLP, self).__init__() + nx = int(opt["transformer_embed_dim"]) + resid_pdrop = 
opt["TRANSFORMER_RESIDUAL_DROPOUT"] + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + self.dropout = nn.Dropout(resid_pdrop) + + """ + Input: + x: batch x len x nx + Output: batch x len x nx + """ + + def forward(self, x): + h = F.relu(self.c_fc(x)) + h2 = self.c_proj(h) + return self.dropout(h2) + + +""" + One encoder block of transformer +""" + + +class EncoderBlock(nn.Module): + def __init__(self, opt): + super(EncoderBlock, self).__init__() + nx = int(opt["transformer_embed_dim"]) + self.one_dir_visible = False + if "transformer_encoder_one_dir_visible" in opt: + self.one_dir_visible = opt["transformer_encoder_one_dir_visible"] + self.splitter = Splitter(nx) + self.attn = Attention(nx, opt) + self.ln_1 = LayerNorm(nx) + self.mlp = MLP(4 * nx, opt) + self.ln_2 = LayerNorm(nx) + + """ + Input: + x: batch x len x n_state + x_mask: batch x len (1 means there's something) + Output: + h: batch x len x n_state + """ + + def forward(self, x, x_mask): + query, key, value = self.splitter(x) + if self.one_dir_visible: + # in this case, use triangle masking, as it's one_direction + a = self.attn(query, key, value, None, one_dir_visible=True) + else: + # in this case, use x_mask for attention masking + a = self.attn(query, key, value, x_mask, one_dir_visible=False) + + n = self.ln_1(x + a) # residual + m = self.mlp(n) + h = self.ln_2(n + m) + return h + + +""" + One encoder block of transformer +""" + + +class DecoderBlock(nn.Module): + def __init__(self, opt): + super(DecoderBlock, self).__init__() + nx = int(opt["transformer_embed_dim"]) + self.decoder_splitter = Splitter(nx) + self.self_attn = Attention(nx, opt) + self.cross_attn = Attention(nx, opt) + self.ln_1 = LayerNorm(nx) + self.ln_2 = LayerNorm(nx) + self.mlp = MLP(4 * nx, opt) + self.ln_3 = LayerNorm(nx) + + """ + Input: + x_mask: batch x len, mask for encoder's input + y: batch x len x n_state (decoder part) + enc_key: batch x encoder_len x n_state + enc_value: batch x encoder_len x n_state + lang_model: whether it's for language model training (no encoder part is used) + Output: + h: batch x len x n_state + """ + + def forward(self, x_mask, y, enc_key, enc_value, lang_model=False): + query, key, value = self.decoder_splitter(y) + # batch x len x n_state + + # self-attention + a = self.self_attn(query, key, value, None, one_dir_visible=True) + # batch x len x n_state + + n = self.ln_1(y + a) # residual + + # seq2seq + if not lang_model: + # src-tgt attention + o = self.cross_attn(n, enc_key, enc_value, x_mask) + p = self.ln_2(n + o) # residual + # batch x len x n_state + else: # language model + p = n + + m = self.mlp(p) + h = self.ln_3(p + m) + return h + + +""" + Embedder +""" + + +class Embedder(nn.Module): + """ + Input: + vocab: size of vocabulary + """ + + def __init__(self, opt, embed=None): + super(Embedder, self).__init__() + n_state = int(opt["transformer_embed_dim"]) # n_state + embed_dropout_rate = opt["TRANSFORMER_EMBED_DROPOUT"] + if embed is None: + self.embed = nn.Embedding(opt["vocab_size"], n_state) + nn.init.normal_(self.embed.weight, std=0.02) + else: + self.embed = embed + self.drop = nn.Dropout(embed_dropout_rate) + self.pos_emb = PositionalEmbedding(opt, n_state) + self.use_cuda = opt["cuda"] + + """ + Input: + x: batch x len (word_id) + Output: + h: batch x len x n_state + """ + + def forward(self, x): + x_emb = self.embed(x) + batch_size = x.shape[0] + x_len = x.shape[1] + x_pos = self.pos_emb( + torch.arange(x_len).type( + torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor + ) 
+ ) # len x n_state + x_pos = ( + Variable( + x_pos.unsqueeze(0).repeat(batch_size, 1, 1), requires_grad=False + ).cuda() + if self.use_cuda + else Variable( + x_pos.unsqueeze(0).repeat(batch_size, 1, 1), requires_grad=False + ) + ) + x_input = x_emb + x_pos + h = self.drop(x_input) + return h + + +""" + Transformer encoder +""" + + +class TransformerEncoder(nn.Module): + """ + Input: + embed: (if not None) pre-computed vocab embeddings + """ + + def __init__(self, opt, embed=None): + super(TransformerEncoder, self).__init__() + vocab = int(opt["vocab_size"]) + n_state = int(opt["transformer_embed_dim"]) + n_layer = int(opt["TRANSFORMER_LAYER"]) + if "vae_z_scale_factor" in opt: + self.vae_z_scale_factor = float(opt["vae_z_scale_factor"]) + + self.embedder = Embedder(opt, embed) + block = EncoderBlock(opt) + self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)]) + self.use_cuda = opt["cuda"] + + """ + Input: + x: batch x len (word_id) + z (optional): batch x len x n_state (for VAE) + Output: + h: batch x len x n_state (word_id) + """ + + def forward(self, x, z=None): + x_mask = ~x.eq(0) # 1 is PAD_id + x_mask = x_mask.type( + torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor + ) + + h = self.embedder(x) + if z is not None: + z *= self.vae_z_scale_factor + h += z + + for block in self.blocks: + h = block(h, x_mask) + return h + + +""" + Transformer decoder +""" + + +class TransformerDecoder(nn.Module): + """ + Input: + embed: (if not None) pre-computed vocab embeddings + """ + + def __init__(self, opt, embed=None): + super(TransformerDecoder, self).__init__() + self.opt = opt + vocab_size = int(opt["vocab_size"]) + n_state = int(opt["transformer_embed_dim"]) # n_state + n_layer = int(opt["TRANSFORMER_LAYER"]) + self.embedder = Embedder(opt, embed) + self.encoder_splitter = Splitter(n_state) + block = DecoderBlock(opt) + self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)]) + if embed is None: + self.linear = Conv1D(vocab_size, n_state) + else: + self.linear = nn.Linear(n_state, vocab_size, bias=False) + if ( + "FINETUNE_RETRAIN_SOFTMAX" not in opt + ): # if FINETUNE_RETRAIN_SOFTMAX, linear needs to be seperately trained + self.linear.weight = embed.weight # share weight + self.use_coda = opt["cuda"] + + """ + Input: + x: batch x encoder_len (word id) + x_out: batch x encoder_len x n_state + y: batch x len (word_id) (decoder part) + lang_model: whether it's for language model training (no encoder part is used) + Output: + prob: batch x len x vocab_size (probabilities after softmax) + """ + + def forward(self, x, x_out, y, lang_model=False): + # seq2seq + if not lang_model: + _, enc_key, enc_value = self.encoder_splitter(x_out) + # enc_key: batch x encoder_len x n_state + # enc_value: batch x encoder_len x n_state + + x_mask = ~x.eq(0) # 1 is PAD_id + x_mask = x_mask.type( + torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor + ) + else: + enc_key = None + enc_value = None + x_mask = None + + h = self.embedder(y) + for block in self.blocks: + h = block(x_mask, h, enc_key, enc_value, lang_model) + prob = F.softmax(self.linear(h), dim=-1) + return prob + + +class TransformerBeam: + """ + Input: + encoder: TransformerEncoder class + decoder: TransformerDecoder class + begin_id: word id of '' + vocab: list of words + """ + + def __init__(self, opt, encoder, decoder, begin_id, vocab): + self.encoder = encoder + self.decoder = decoder + self.opt = opt + self.max_sent_len = int(opt["max_sent_len"]) + self.begin_id = begin_id + 
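        # vocab maps token ids back to words so the search can spot the
        # end-of-sentence token; beam_width bounds how many hypotheses are
        # kept per decoding step.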
self.vocab = vocab + self.beam_width = int(opt["beam_width"]) + self.use_cuda = opt["cuda"] + + # each candidate is (idx, prob, 0/1, position/wordid) + def merge_candidates(self, cand_A, cand_B): + C = [] + pA, lA, pB, lB = 0, len(cand_A), 0, len(cand_B) + lC = 0 + while (pA < lA or pB < lB) and (lC < self.beam_width): + if pA < lA and (pB >= lB or cand_A[pA][1] > cand_B[pB][1]): + C.append(cand_A[pA]) + pA += 1 + else: + C.append(cand_B[pB]) + pB += 1 + lC += 1 + return C + + """ + Input: + x = batch * encoder_len (word_ids) encoder's input + k: top-k sampling + Output: + sents: list of words, with batch items, each one with up to beam_width (sentence, log_prob), each sentence with up to max_sent_len_word words + """ + + def topk(self, x, k): + batch_size = x.shape[0] + x_len = x.shape[1] + x_out = self.encoder(x) + # x_out: batch x encoder_len x n_state + + # sent_ids is the words for each of the batch_size sentences + sent_ids = [] + for i in range(batch_size): + sent_ids.append([self.begin_id]) + + topk = 1 + MIN_GEN_LENGTH = 45 + if "MIN_GEN_LENGTH" in self.opt: + MIN_GEN_LENGTH = int(self.opt["MIN_GEN_LENGTH"]) + for l in range(self.max_sent_len): + y = ( + Variable(torch.LongTensor(sent_ids)).cuda() + if self.use_cuda + else Variable(torch.LongTensor(sent_ids)) + ) # batch_size x l + decoder_outputs = self.decoder(x, x_out, y) + probs = decoder_outputs[ + :, -1, : + ] # batch_size x vocab_size (only take the last output) + for i in range(batch_size): + topk_probs, _ = torch.topk(probs[i], k) + threshold = float(topk_probs[-1]) + probs[i][probs[i] < threshold] = 0.0 + + samples = torch.multinomial( + probs, 2 + ) # sample 2 since the first one may be + for i in range(batch_size): + if l < MIN_GEN_LENGTH and self.vocab[int(samples[i, 0])] == "": + sent_ids[i].append(int(samples[i, 1])) + else: + sent_ids[i].append(int(samples[i, 0])) + + sents = [] + for i in range(batch_size): + utt = [] + for j in range(len(sent_ids[i])): + w = self.vocab[sent_ids[i][j]] + if w == "": + continue + if w == "": + break + utt.append(w) + sents.append([(utt, 0)]) + + return sents + + """ + Input: + x = batch * encoder_len (word_ids) encoder's input + Output: + sents: list of words, with batch items, each one with up to beam_width (sentence, log_prob), each sentence with up to max_sent_len_word words + """ + + def beam_search(self, x): + batch_size = x.shape[0] + x_len = x.shape[1] + x_out = self.encoder(x) + # x_out: batch x encoder_len x n_state + + sents = [] + topk = 1 + history_nodes = [{}] + end_nodes = {} + for idx in range(batch_size): + start_node = BeamSearchNode([self.begin_id], 0, 1) + history_nodes[0][idx] = [start_node] + end_nodes[idx] = [] + + for l in range(self.max_sent_len): + last_nodes = history_nodes[-1] + if sum([len(l) for i, l in last_nodes.items()]) == 0: # no nodes left + break + ys = [] + x_outs = [] + xs = [] + for idx in range(batch_size): + ys.extend([node.word_ids for node in last_nodes[idx]]) + x_outs.extend( + [x_out[idx, :, :].unsqueeze(0) for node in last_nodes[idx]] + ) + xs.extend([x[idx, :].unsqueeze(0) for node in last_nodes[idx]]) + + ys = ( + Variable(torch.LongTensor(ys)).cuda() + if self.use_cuda + else Variable(torch.LongTensor(ys)) + ) # N x l + x_outs = torch.cat(x_outs, dim=0) # N x x_len x n_state + xs = torch.c at(xs, dim=0) # N x x_len + probs = self.decoder(xs, x_outs, ys) + log_probs = torch.log( + probs[:, -1, :] + 1e-15 + ) # N x vocab_size (only take the last output) + + history_nodes.append({}) + p = 0 + for idx in range(batch_size): + 
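                # For this batch item: expand every surviving hypothesis with its
                # top-beam_width next tokens, keep only the best beam_width
                # candidates overall, drop candidates that would repeat a trigram,
                # and move hypotheses that emit the end-of-sentence token (or hit
                # the length limit) from history_nodes into end_nodes.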
history_nodes[-1][idx] = [] + N = len(last_nodes[idx]) + if N == 0: + continue + log_prob = log_probs[p : p + N] + p += N + # log_prob = N x extended_vocab_size + + # generate + candidates = [] + for k in range(N): + logprobs, ids = torch.topk(log_prob[k], self.beam_width) + candidates = self.merge_candidates( + candidates, [(k, p, d) for p, d in zip(logprobs, ids)] + ) + + candidates = candidates[: self.beam_width] + extended_nodes_in_last_nodes = set() + for k in range(len(candidates)): + h, logp, next_word_id = candidates[ + k + ] # h means "the h-th node in last_nodes" + logp = float(logp) + next_word_id = int(next_word_id) + prev_node = last_nodes[idx][h] + next_wordids = prev_node.word_ids + [next_word_id] + next_word = self.vocab[next_word_id] + + next_node = BeamSearchNode( + next_wordids, prev_node.log_prob + logp, prev_node.length + 1 + ) + if next_node.duplicate == False: # no duplicate trigram generated + extended_nodes_in_last_nodes.add(h) + if next_word == "" or l == self.max_sent_len - 1: + end_nodes[idx].append((next_node.eval(), next_node)) + else: + history_nodes[-1][idx].append(next_node) + + special_words = ["", "", "", "", "", ""] + for k in range(N): + if k not in extended_nodes_in_last_nodes: + node = last_nodes[idx][k] + effective_word_count = sum( + [ + 1 + for x in node.word_ids + if self.vocab[x] not in special_words + ] + ) + if effective_word_count >= 5: + end_nodes[idx].append((node.eval(), node)) + + MIN_GEN_LENGTH = 45 + if "MIN_GEN_LENGTH" in self.opt: + MIN_GEN_LENGTH = int(self.opt["MIN_GEN_LENGTH"]) + for idx in range(batch_size): + t = len([w for w in end_nodes[idx] if w[1].length > MIN_GEN_LENGTH]) + if t > 0: + end_nodes[idx] = [ + w for w in end_nodes[idx] if w[1].length > MIN_GEN_LENGTH + ] + + end_nodes[idx].sort(key=lambda tup: tup[0], reverse=True) + candidates = [] + for score, node in end_nodes[idx][:topk]: + utt = [self.vocab[x] for x in node.word_ids] + utt = [x for x in utt if x not in ["", ""]] + candidates.append((utt, score)) + if len(candidates) == 0: + candidates.append(("", 0)) + sents.append(candidates) + + return sents + + +class BeamSearchNode(object): + def __init__(self, word_ids, log_prob, length): + self.word_ids = word_ids + self.log_prob = log_prob + self.length = length + + trigram_set = set() + self.duplicate = False + + for i in range(2, len(word_ids)): + trigram = ( + str(word_ids[i - 2]) + + " " + + str(word_ids[i - 1]) + + " " + + str(word_ids[i]) + ) + if trigram in trigram_set: + self.duplicate = True + break + trigram_set.add(trigram) + + def eval(self): + return self.log_prob / float(self.length - 1.0 + 1e-6) + + def __lt__(self, other): + return self.length < other.length diff --git a/model/third_party/HMNet/Models/Optimizers/LnrWrmpInvSqRtDcyScheduler.py b/model/third_party/HMNet/Models/Optimizers/LnrWrmpInvSqRtDcyScheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ce98d92c4eb2fcd9b688c8ca6d8fb49a842875 --- /dev/null +++ b/model/third_party/HMNet/Models/Optimizers/LnrWrmpInvSqRtDcyScheduler.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
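# The scheduler defined below returns a multiplicative factor on the base
# learning rate (warmup_end_lr): a linear ramp from warmup_init_lr to
# warmup_end_lr over the first warmup_steps updates, then inverse-square-root
# decay. A minimal standalone sketch of that factor, assuming the same
# constructor arguments as the class below; the numbers are illustrative only.

import math


def inv_sqrt_factor(step, warmup_steps=1000, warmup_init_lr=1e-7, warmup_end_lr=1e-3):
    # Linear warmup phase: the factor ramps up to 1.0 at step == warmup_steps.
    lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps
    if step < warmup_steps:
        return (warmup_init_lr + step * lr_step) / warmup_end_lr
    # Decay phase: e.g. at step == 4 * warmup_steps the factor is 0.5.
    return 1.0 / math.sqrt(step / float(warmup_steps))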
+ +import math +from torch.optim.lr_scheduler import LambdaLR + + +class LnrWrmpInvSqRtDcyScheduler(LambdaLR): + """Inverse Square Root learning rate schedule used in T5""" + + def __init__(self, optimizer, warmup_steps, warmup_init_lr, warmup_end_lr): + self.warmup_steps = warmup_steps + self.warmup_init_lr = warmup_init_lr + self.warmup_end_lr = warmup_end_lr + self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_steps + super(LnrWrmpInvSqRtDcyScheduler, self).__init__( + optimizer, self.lr_lambda, last_epoch=-1 + ) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return (self.warmup_init_lr + step * self.lr_step) / self.warmup_end_lr + else: + return 1.0 / float(math.sqrt(step / float(self.warmup_steps))) + + def get_last_lr(self): + return self.get_lr() diff --git a/model/third_party/HMNet/Models/Optimizers/RAdam.py b/model/third_party/HMNet/Models/Optimizers/RAdam.py new file mode 100644 index 0000000000000000000000000000000000000000..b74642c2f8870d37d0faa9a4824f2bb8c5fbe331 --- /dev/null +++ b/model/third_party/HMNet/Models/Optimizers/RAdam.py @@ -0,0 +1,247 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import math +import torch +from torch.optim.optimizer import Optimizer, required + + +class RAdam(Optimizer): + """ + @article{liu2019radam, + title={On the Variance of the Adaptive Learning Rate and Beyond}, + author={Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei}, + journal={arXiv preprint arXiv:1908.03265}, + year={2019} + } + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - bet a1, grad) + + state["step"] += 1 + buffered = self.buffer[int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = ( + group["lr"] + * math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) + / (1 - beta1 ** state["step"]) + ) + else: + step_size = group["lr"] / (1 - beta1 ** state["step"]) + buffered[2] = step_size + + if group["weight_decay"] != 0: + 
p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + + +class PlainRAdam(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state["step"] += 1 + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = ( + group["lr"] + * math.sqrt( + (1 - beta2_t) + * (N_sma - 4 ) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) + / (1 - beta1 ** state["step"]) + ) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + step_size = group["lr"] / (1 - beta1 ** state["step"]) + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + + +class AdamW(Optimizer): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0 + ): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, warmup=warmup + ) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + 
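                    # Adam's moment estimates are biased toward zero early in
                    # training; the (1 - beta ** step) bias-correction terms below
                    # undo that. When "warmup" is configured, a linearly ramped
                    # learning rate is also computed and applied to the decoupled
                    # weight-decay term.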
denom = exp_avg_sq.sqrt().add_(group["eps"]) + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + if group["warmup"] > state["step"]: + scheduled_lr = 1e-8 + state["step"] * group["lr"] / group["warmup"] + else: + scheduled_lr = group["lr"] + + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * scheduled_lr, p_data_fp32) + + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/model/third_party/HMNet/Models/Trainers/BaseTrainer.py b/model/third_party/HMNet/Models/Trainers/BaseTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..532070d5d776b1a2b9435522b7fe9d03224ff87f --- /dev/null +++ b/model/third_party/HMNet/Models/Trainers/BaseTrainer.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os + + +class BaseTrainer: + def __init__(self, opt): + self.opt = opt + if self.opt["cuda"] == True: + self.use_cuda = True + print("Using Cuda\n") + else: + self.use_cuda = False + print("Using CPU\n") + + self.is_official = "OFFICIAL" in self.opt + self.opt["logFile"] = "log.txt" + self.saveFolder = None + self.logFileHandle = None + self.tb_writer = None + + def log(self, s): + # In official c ase, the program does not output logs + if self.is_official: + return + try: + if self.logFileHandle is None: + self.logFileHandle = open( + os.path.join(self.saveFolder, self.opt["logFile"]), "a" + ) + self.logFileHandle.write(s + "\n") + except Exception as e: + print("ERROR while writing log file:", e) + print(s) + + def getSaveFolder(self): + runid = 1 + while True: + saveFolder = os.path.join( + self.opt["datadir"], + self.opt["basename"] + "_conf~", + "run_" + str(runid), + ) + if not os.path.exists(saveFolder): + self.saveFolder = saveFolder + os.makedirs(self.saveFolder) + print("Saving logs, model and evaluation in " + self.saveFolder) + return + runid = runid + 1 + + # save copy of conf file + def saveConf(self): + # with open(self.opt['confFile'], encoding='utf-8') as f: + # with open(os.path.join(self.saveFolder, 'conf_copy.tsv'), 'w', encoding='utf-8') as fw: + # for line in f: + # fw.write(line) + with open( + os.path.join(self.saveFolder, "conf_copy.tsv"), "w", encoding="utf-8" + ) as fw: + for k in self.opt: + fw.write("{0}\t{1}\n".format(k, self.opt[k])) + + def train(self): + pass + + def load(self): + pass diff --git a/model/third_party/HMNet/Models/Trainers/DistributedTrainer.py b/model/third_party/HMNet/Models/Trainers/DistributedTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae8bf565f151c8746033f7832a17e0e9ea0b6f3 --- /dev/null +++ b/model/third_party/HMNet/Models/Trainers/DistributedTrainer.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
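# The trainer below logs the effective batch size as the per-GPU batch
# multiplied by the number of processes and, when configured, the number of
# gradient-accumulation steps. A minimal sketch of that bookkeeping; the
# argument names mirror the MINI_BATCH / world_size / GRADIENT_ACCUMULATE_STEP
# options used below, and the values in the example are illustrative only.


def effective_batch_size(mini_batch, world_size, grad_acc_steps=1):
    # e.g. effective_batch_size(4, 8, grad_acc_steps=2) == 64
    return mini_batch * world_size * grad_acc_steps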
+ +import os +import torch +from torch.utils.tensorboard import SummaryWriter +import random +import numpy as np + +from pkg_resources import parse_version +from model.third_party.HMNet.Models.Trainers.BaseTrainer import BaseTrainer +from model.third_party.HMNet.Utils.GeneralUtils import bcolors +from model.third_party.HMNet.Utils.distributed import distributed + + +class DistributedTrainer(BaseTrainer): + def __init__(self, opt): + super().__init__(opt) + + self.seed = int(self.opt["SEED"]) if "SEED" in self.opt else 0 + + random.seed(self.seed) + np.random.seed(self.seed) + torch.manual_seed(self.seed) + + ( + self.opt["device"], + _, + self.opt["world_size"], + self.opt["local_size"], + self.opt["rank"], + self.opt["local_rank"], + _, + self.opt["run"], + ) = distributed(opt, not self.use_cuda) + + self.getSaveFolder() + self.opt["logFile"] = f"log_{self.opt['rank']}.txt" + self.saveConf() + + self.high_pytorch_version = parse_version(torch.__version__) >= parse_version( + "1.2.0" + ) + if self.opt["rank"] == 0: + print( + bcolors.OKGREEN, + torch.__version__, + bcolors.ENDC, + "is", + "high" if self.high_pytorch_version else "low", + ) + + if self.use_cuda: + # torch.cuda.manual_seed_all(self.seed) + # ddp: only set seed on GPU associated with this process + torch.cuda.manual_seed(self.seed) + + # ddp: print stats and update learning rate + if self.opt["rank"] == 0: + print( + "Number of GPUs is", + bcolors.OKGREEN, + self.opt["world_size"], + bcolors.ENDC, + ) + # print('Boost learning rate from', bcolors.OKGREEN, self.opt['START_LEARNING_RATE'], bcolors.ENDC, 'to', + # bcolors.OKGREEN, self.opt['START_LEARNING_RATE'] * self.opt ['world_size'], bcolors.ENDC) + print( + "Effective batch size is increased from", + bcolors.OKGREEN, + self.opt["MINI_BATCH"], + bcolors.ENDC, + "to", + bcolors.OKGREEN, + self.opt["MINI_BATCH"] * self.opt["world_size"], + bcolors.ENDC, + ) + + self.grad_acc_steps = 1 + if "GRADIENT_ACCUMULATE_STEP" in self.opt: + if self.opt["rank"] == 0: + print( + "Gradient accumulation steps =", + bcolors.OKGREEN, + self.opt["GRADIENT_ACCUMULATE_STEP"], + bcolors.ENDC, + ) + # print('Boost learning rate from', bcolors.OKGREEN, self.opt['START_LEARNING_RATE'], bcolors.ENDC, 'to', + # bcolors.OKGREEN, self.opt['START_LEARNING_RATE'] * self.opt['world_size'] * self.opt['GRADIENT_ACCUMULATE_STEP'], bcolors.ENDC) + print( + "Effective batch size =", + bcolors.OKGREEN, + self.opt["MINI_BATCH"] + * self.opt["world_size"] + * self.opt["GRADIENT_ACCUMULATE_STEP"], + bcolors.ENDC, + ) + self.grad_acc_steps = int(self.opt["GRADIENT_ACCUMULATE_STEP"]) + # self.opt['START_LEARNING_RATE'] *= self.opt['world_size'] * self.grad_acc_steps + + def tb_log_scalar(self, name, value, step): + if self.opt["rank"] == 0: + if self.tb_writer is None: + self.tb_writer = SummaryWriter( + os.path.join(self.saveFolder, "tensorboard") + ) + self.tb_writer.add_scalar(name, value, step) + + def log(self, s): + # When 'OFFICIAL' flag is set in the config file, the program does not output logs + if self.is_official: + return + try: + if self.logFileHandle is None: + self.logFileHandle = open( + os.path.join(self.saveFolder, self.opt["logFile"]), "a" + ) + self.logFileHandle.write(s + "\n") + except Exception as e: + print("ERROR while writing log file:", e) + print(s) + + def getSaveFolder(self): + runid = 1 + while True: + saveFolder = os.path.join( + self.opt["datadir"], + self.opt["basename"] + "_conf~", + "run_" + str(runid), + ) + if not os.path.isdir(saveFolder): + if self.opt["world_size"] > 1: + 
torch.distributed.barrier() + if self.opt["rank"] == 0: + os.makedirs(saveFolder) + self.saveFolder = saveFolder + if self.opt["world_size"] > 1: + torch.distributed.barrier() + print( + "Saving logs, model, checkpoint, and evaluation in " + + self.saveFolder + ) + return + runid = runid + 1 + + def saveConf(self): + if self.opt["rank"] == 0: + super().saveConf() diff --git a/model/third_party/HMNet/Models/Trainers/HMNetTrainer.py b/model/third_party/HMNet/Models/Trainers/HMNetTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..771e4883f7325e18d99c7ef1686fb1393a36ebe4 --- /dev/null +++ b/model/third_party/HMNet/Models/Trainers/HMNetTrainer.py @@ -0,0 +1,689 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections import defaultdict +from datetime import datetime +import os +import sys +import importlib +import json +import random +import numpy as np +import inspect +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler + +from model.third_party.HMNet.Models.Trainers.DistributedTrainer import ( + DistributedTrainer, +) +fr om model.third_party.HMNet.Models.Trainers.Tasks import Task +from model.third_party.HMNet.Utils.GeneralUtils import ( + AverageMeter, + BaseBatchGen, + bcolors, +) + +from model.third_party.HMNet.DataLoader import iterators + + +class ObjectView(object): + def __init__(self, d): + self.__dict__ = d + + +class WrappedModel(nn.Module): + def __init__(self, model, criterion): + super(WrappedModel, self).__init__() + self.add_module("model", model) + self.add_module("criterion", criterion) + + def forward(self, batch): + output = self.model(batch) + loss = self.criterion(output, batch) + return loss + + +class HMNetTrainer(DistributedTrainer): + """ + The trainer class for HMNet model training (pre-train and fine-tune.) + Its train() and eval() methods are intended to directly called to + start training and evaluation respectively. + + Before running, the trainer must contain proper Task, Criterion, and Optimizer + instances. + + """ + + def __init__(self, opt): + super().__init__(opt) + self.task = Task.setup_task(self.opt["TASK"], self.opt, self.saveFolder) + + def is_gradient_accumulation_boundary(self): + return (self.updates + 1) % self.grad_acc_steps == 0 + + def get_batch_generator(self, dataset_label): + batch_generator = self.task.batch_gen( + self.opt, + dataset_label=dataset_label, + model_config=self.module.config, + tokenizer=self.module.tokenizer, + world_size=self.opt["world_size"], + rank=self.opt["rank"], + seed=self.seed, + ) + if isinstance(batch_generator, BaseBatchGen): + # If it is a wrapper class of an infinibatch iterator, + # get the internal infnitibatch iterator. + batch_generator = batch_generator.iterator + self.log(f"Loaded data on rank {self.opt['rank']}.") + return batch_generator + + def set_up_model(self): + # instantiate module (tokenizer should be contained in module as self.module.tokenizer) + try: + model_module = importlib.import_module( + "model.third_party.HMNet.Models.Networks." 
+ self.opt["MODEL"] + ) + model_class = getattr(model_module, self.opt["MODEL"]) + self.module = model_class(self.opt) + except Exception as e: + self.log(e) + self.log("ERROR: Model {} is unknown".format(self.opt["MODEL"])) + assert False + + # calculate total trainable parameters + pytorch_total_params = sum( + p.numel() for p in self.module.parameters() if p.requires_grad + ) + self.log("Total trainable parameters: {}".format(pytorch_total_params)) + + # instantiate criterion + try: + criterion_module = importlib.import_module( + "model.third_party.HMNet.Models.Criteria." + self.opt["CRITERION"] + ) + criterion_class = getattr(criterion_module, self.opt["CRITERION"]) + self.criterion = criterion_class(self.opt, self.module) + except Exception as e: + self.log(e) + self.log("ERROR: Criterion {} is unknown".format(self.opt["CRITERION"])) + assert False + + self.module.to(self.opt["device"]) + + def get_optimizer_params_config(self, optimizer_class): + optimizer_parameters = {} + sig = inspect.signature(optimizer_class) + for param_name in sig.parameters.keys(): + if param_name == "lr": + optimizer_parameters[param_name] = self.opt["START_LEARNING_RATE"] + if param_name not in ["params", "lr"] and param_name.upper() in self.opt: + optimizer_parameters[param_name] = self.opt[param_name.upper()] + return optimizer_parameters + + def get_lr_scheduler_params_config(self, lr_scheduler_class): + lr_scheduler_parameters = {} + sig = inspect.signature(lr_scheduler_class) + for param_ name in sig.parameters.keys(): + if param_name not in ["optimizer"] and param_name.upper() in self.opt: + lr_scheduler_parameters[param_name] = self.opt[param_name.upper()] + return lr_scheduler_parameters + + def set_up_optimizer_and_lr_scheduler(self): + + parameters = self.module.get_training_parameters() + + # instantiate optimizer + try: # first try pytorch native optimizer + optimizer_class = getattr(optim, self.opt["OPTIMIZER"]) + self.log( + "Using pytorch native optimizier: {}".format(self.opt["OPTIMIZER"]) + ) + except: + try: # then try custom optimizer inside Models.Optimizers + optimizer_module = importlib.import_module( + "model.third_party.HMNet.Models.Optimizers." + self.opt["OPTIMIZER"] + ) + optimizer_class = getattr(optimizer_module, self.opt["OPTIMIZER"]) + self.log("Using custom optimizer: {}".format(self.opt["OPTIMIZER"])) + except Exception as e: + self.log(e) + self.log("ERROR: Optimizer {} is unknown".format(self.opt["OPTIMIZER"])) + assert False + + optimizer_parameters = self.get_optimizer_params_config(optimizer_class) + self.log(f"Optimizer parameters: {optimizer_parameters}") + self.optimizer = optimizer_class(parameters, **optimizer_parameters) + self.optimizer.zero_grad() + + # instantiate lr scheduler + try: # first look for pytorch native lr scheduler + lr_scheduler_class = getattr(lr_scheduler, self.opt["LR_SCHEDULER"]) + self.log( + "Using pytorch native lr scheduler: {}".format(self.opt["LR_SCHEDULER"]) + ) + except: + try: # then look for custom lr scheduler inside Models.Optimizers + lr_scheduler_module = importlib.import_module( + "model.third_party.HMNet.Models.Optimizers." 
+ + self.opt["LR_SCHEDULER"] + ) + lr_scheduler_class = getattr( + lr_scheduler_module, self.opt["LR_SCHEDULER"] + ) + self.log( + "Using custom lr scheduler: {}".format(self.opt["LR_SCHEDULER"]) + ) + except Exception as e: + self.log(e) + self.log( + "ERROR: LR Scheduler {} is unknown".format(self.opt["LR_SCHEDULER"]) + ) + assert False + + lr_scheduler_parameters = self.get_lr_scheduler_params_config( + lr_scheduler_class + ) + self.log(f"Lr scheduler parameters: {lr_scheduler_parameters}") + self.lr_scheduler = lr_scheduler_class( + self.optimizer, **lr_scheduler_parameters + ) + + def initialize_fp16_DDP(self): + """ + Wrap the module and criterion to a single network, then depending on the settings, + wrap the network with apex amp module for fp16 training, and wrap the network with + pytorch DDP module for distributed data parallel training + """ + self.network = WrappedModel(self.module, self.criterion) + self.network.to(self.opt["device"]) + + if self.opt["fp16"]: + from apex import amp + + self.network, self.optimizer = amp.initialize( + self.network, self.optimizer, opt_level=self.opt["fp16_opt_level"] + ) + + if self.opt["world_size"] > 1: + self.network = torch.nn.parallel.DistributedDataParallel( + self.network, + device_ids=[self.opt["local_rank"]], + output_device=self.opt["local_rank"], + find_unused_parameters=True, + ) + self.log(f"Wrapped model with DDP on rank {self.opt['rank']}.") + assert self.module is self.network.module.model + else: + assert self.module is self.network.model + + def eval(self): + if self.opt["rank"] == 0: + self.log("-----------------------------------------------") + self.log("Evaluating model ... ") + self.set_up_model() + + for eval_dataset in ["dev", "test"]: + batch_generator_eval = self.get_batch_generator(eval_dataset) + + self.task.evaluator.reset_best_score(set_high=True) + result, score, got_better_score = self.task.evaluator.eval_batches( + self.module, batch_generator_eval, self.saveFolder, eval_dataset + ) + if self.opt["rank"] == 0: + self.log("{0} results breakdown\n{1}".format(eval_dataset, result)) + + def eval_return_results(self): + if self.opt["rank"] == 0: + self.log("-----------------------------------------------") + self.log("Evaluating model ... 
") + self.set_up_model() + + for eval_dataset in ["test"]: + batch_generator_eval = self.get_batch_generator(eval_dataset) + + self.task.evaluator.reset_best_score(set_high=True) + result, score, got_better_score = self.task.evaluator.eval_batches( + self.module, batch_generator_eval, self.saveFolder, eval_dataset + ) + if self.opt["rank"] == 0: + self.log("{0} results breakdown\n{1}".format(eval_dataset, result)) + return result + + def train(self): + self.log(f"train on rank {self.opt['rank']}") + if self.opt["rank"] == 0: + self.log("-----------------------------------------------") + self.log("Initializing model...") + + self.set_up_model() # setup self.module as original model + self.network = None + self.train_batch_generator = self.get_batch_generator("train") + if isinstance(self.train_batch_generator, iterators.CheckpointableIterator): + # training batch generator is infinite + self.updates_per_epoch = self.opt["UPDATES_PER_EPOCH"] + else: + self.updates_per_epoch = len(self.train_batch_generator) + self.updates = 0 + self.optim_steps = 0 + self.start_epoch_idx = 0 + self.start_batch_idx = 0 + + self.set_up_optimizer_and_lr_scheduler() + self.initialize_fp16_DDP() + if "RESUME" in self.opt: + # Resume complete training states, including optimizer, lr_scheduler, train batch generator, and updates count + # from the checkpoint location indicated in a .json file + self.load_checkpoint() + + ###################### + # Start the main loop + ###################### + + numEpochs = self.opt["MAX_NUM_EPOCHS"] + self.train_loss = AverageMeter() # track the average training loss + self.acc_loss = 0.0 + # after every 'SAVE_PER_UPDATE_NUM' updates, it will save a checkpoint by setting save_a_checkpoint to True temporarily + save_a_checkpoint = False + for epoch in range(self.start_epoch_idx, numEpochs): + self.current_epoch_idx = epoch + self.log("Epoch {}".format(epoch)) + + startTime = datetime.now() + + for batch_idx, batch in enumerate(self.train_batch_generator): + if self.current_epoch_idx == self.start_epoch_idx: + if isinstance( + self.train_batch_generator, iterators.CheckpointableIterator + ): + batch_idx += self.start_batch_idx + elif batch_idx < self.start_batch_idx: + continue + self.current_batch_idx = batch_idx + + # after every 'SAVE_PER_UPDATE_NUM' updates, save a checkpoint + if ("SAVE_PER_UPDATE_NUM" in self.opt) and ( + self.updates + 1 + ) % self.opt["SAVE_PER_UPDATE_NUM"] == 0: + # Make sure the next update is going to update the weights and zero the gradients, then we can checkpoint + asse rt self.is_gradient_accumulation_boundary() + save_a_checkpoint = True + + # update + self.update(batch) + + if save_a_checkpoint: + # evaluate at the checkpointed moment, and log the results + if self.task.evaluator is not None: + evaluate_label = "update_" + str(self.updates) + eval_dataset = "dev" + batches = self.get_batch_generator(eval_dataset) + ( + result, + score, + got_better_score, + ) = self.task.evaluator.eval_batches( + self.module, batches, self.saveFolder, evaluate_label + ) + self.tb_log_scalar("Eval/score", score, self.updates) + if got_better_score: + self.log( + "Got new better score on rank-{0} evaluator, at updates {1}".format( + self.opt["rank"], self.updates + ) + ) + self.log( + "Updates {0} - {1}: Current Score: {2:.3f} (best Score: {3:.3f})".format( + self.updates, + eval_dataset, + score, + self.task.evaluator.best_score, + ) + ) + self.log("Current results breakdown\n{0}".format(result)) + self.log( + "Best results breakdown\n{0}".format( + 
self.task.evaluator.best_res
+                            )
+                        )
+                    # save complete training states, including model weights, optimizer, lr_scheduler, batch generator, and updates count
+                    self.save_checkpoint(self.updates)
+                    save_a_checkpoint = False
+
+                # logging
+                if (
+                    (batch_idx % 10 == 0)
+                    or (epoch == 0 and batch_idx <= 50)
+                    or "DEBUG" in self.opt
+                ):
+                    if self.opt["rank"] == 0:
+                        batch_size = batch["encoder_input_ids"].shape[0]
+                        self.log(
+                            "epochs[{0:6}] updates[{1:6}] bsz[{2:d}] train loss[{3:.5f}] avg train loss[{4:.5f}] learning rate[{5:.5e}] remaining[{6}]".format(
+                                epoch,
+                                self.updates,
+                                batch_size,
+                                self.train_loss.val,
+                                self.train_loss.avg,
+                                self.lr_scheduler.get_lr()[0],
+                                str(
+                                    (datetime.now() - startTime)
+                                    / (batch_idx + 1)
+                                    * (self.updates_per_epoch - batch_idx - 1)
+                                ).split(".")[0],
+                            )
+                        )
+
+                        self.tb_log_scalar(
+                            "Loss/train_val", self.train_loss.val, self.updates
+                        )
+                        self.tb_log_scalar(
+                            "Loss/train_avg", self.train_loss.avg, self.updates
+                        )
+                        self.tb_log_scalar(
+                            "Learning Rate/lr",
+                            self.lr_scheduler.get_lr()[0],
+                            self.updates,
+                        )
+
+                # if "DEBUG" in self.opt and batch_idx > 200:  # exit early for DEBUG mode
+                #     break
+
+                if (
+                    isinstance(
+                        self.train_batch_generator, iterators.CheckpointableIterator
+                    )
+                    and batch_idx + 1 == self.updates_per_epoch
+                ):
+                    break
+
+            self.log("This epoch takes " + str(datetime.now() - startTime))
+            self.log("PROGRESS: {0:.2f}%".format(100.0 * (epoch + 1) / numEpochs))
+            self.log("Config file is at " + self.opt["confFile"])
+
+            if "DEBUG" in self.opt:  # exit early for DEBUG mode
+                break
+
+    def update(self, batch):
+        # forward loss, backward propagation, model update, and one step of optimization and lr scheduler
+        self.network.train()
+        # put the batch to the device
+        # @TODO make this more general, maybe have a self.task.move_batch(batch, device)
+        # so the trainer decides when and where to move batches, and task tells how
+        if isinstance(batch, tuple):
+            batch = tuple(t.to(self.opt["device"]) for t in batch)
+        elif isinstance(batch, list):
+            batch = [t.to(self.opt["device"]) for t in batch]
+        elif isinstance(batch, dict):
+            for k in batch:
+                if torch.is_tensor(batch[k]):
+                    batch[k] = batch[k].to(self.opt["device"])
+        else:
+            assert torch.is_tensor(batch)
+            batch = batch.to(self.opt["device"])
+
+        # determine whether gradient sync can be skipped or not for this update
+        skip_gradient_sync = False
+        if self.opt["world_size"] > 1 and not self.is_gradient_accumulation_boundary():
+            if not self.opt["fp16"]:
+                # https://krishansubudhi.github.io/deeplearning/2020/02/06/apex-gradient-accumulation.html
+                # When using fp16, if we skip grad sync during grad accumulation, the grad sync at the
+                # grad accumulation boundary cannot properly sync the whole accumulated grad.
+                # So with fp16 on, we have to sync even if it's not grad accumulation boundary.
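+                # Note: skipping the sync relies on DDP's no_sync() context used in the forward and
+                # backward calls below, which is presumably why it is further gated on self.high_pytorch_version.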
+ if self.high_pytorch_version: + skip_gradient_sync = True + + # forward + if skip_gradient_sync: + with self.network.no_sync(): + loss = self.network(batch) + else: + loss = self.network(batch) + if self.grad_acc_steps > 1: + loss = loss / self.grad_acc_steps + self.acc_loss += loss + # self.log(f"forward() done on rank {self.opt['rank']}") + # print(loss.item()) + + # backward + def backward(loss_tensor): + if self.opt["fp16"]: + from apex import amp + + with amp.scale_loss(loss_tensor, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss_tensor.backward() + + if skip_gradient_sync: + with self.network.no_sync(): + backward(loss) + else: + if "DEBUG" in self.opt and self.opt["rank"] == 0: + self.log( + "Performing synchronized backward at step {0}".format( + self.optim_steps + ) + ) + backward(loss) + # self.log(f"backward() done on rank {self.opt['rank']}") + + # step + if self.is_gradient_accumulation_boundary(): + if self.opt["world_size"] > 1: + # ddp: use all_reduce to sum up values of self.acc_loss over all processes + # the operations happens in place (i.e., the value of self.acc_loss is replaced) and all processes received the updated value + torch.distributed.all_reduce( + self.acc_loss, torch.distributed.ReduceOp.SUM + ) + self.acc_loss /= self.opt["world_size"] + self.train_loss.update(self.acc_loss.data, 1) + self.acc_loss = 0.0 + if "GRAD_CLIPPING" in self.opt: + if self.opt["fp16"]: + from apex import amp + + torch.nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), self.opt["GRAD_CLIPPING"] + ) + else: + torch.nn.utils.clip_grad_norm_( + self.network.parameters(), self.opt["GRAD_CLIPPING"] + ) + self.optim_steps += 1 + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + + self.updates += 1 + # self.log(f"step() done on rank {self.opt['rank']}") + + def save_checkpoint(self, tag): + """ + Save complete training states, including model weights, optimizer, lr_scheduler, + fp16 loss scaler, random state, batch generator, and updates count + Also save a model with save_pretrained API for model transfer + """ + self.log("Saving checkpoint...") + resume_epoch_idx = self.current_epoch_idx + resume_batch_idx = self.current_batch_idx + 1 + if resume_batch_idx == self.updates_per_epoch: + resume_batch_idx = 0 + resume_epoch_idx += 1 + + if self.opt["fp16"]: + from apex import amp + if self.opt["rank"] == 0: + save_dir = os.path.join(self.saveFolder, str(tag)) + os.makedirs(save_dir) + save_path = os.path.join(save_dir, "training_states.pt") + state = { + "network": self.network.state_dict(), + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "amp": amp.state_dict() if self.opt["fp16"] else None, + "optim_steps": self.optim_steps, + "updates": self.updates, + "updates_per_epoch": self.updates_per_epoch, + "start_epoch_idx": resume_epoch_idx, + "start_batch_idx": resume_batch_idx, + } + + torch.save(state, save_path) + if self.opt["world_size"] > 1: + torch.distributed.barrier() + save_dir = os.path.join(self.saveFolder, str(tag)) + assert os.path.isdir(save_dir) + + random_state_path = os.path.join( + save_dir, "random_state_rank_{:04d}".format(self.opt["rank"]) + ) + random_state = { + "random": random.getstate(), + "numpy_random": np.random.get_state(), + "torch_random": torch.get_rng_state(), + "torch_cuda_random": torch.cuda.get_rng_state(device=self.opt["device"]) + if self.use_cuda + else None, + } + torch.save(random_state, random_state_path) + + if 
isinstance(self.train_batch_generator, iterators.CheckpointableIterator): + # save batch generators for all ranks + batch_generator_file_path = os.path.join( + save_dir, + "batch_generator_checkpoint_rank_{:04d}".format(self.opt["rank"]), + ) + batch_generator_state = self.train_batch_generator.getstate() + torch.save(batch_generator_state, batch_generator_file_path) + else: + self.log( + "Batch generator is not checkpointable. Cannot save to checkpoint." + ) + + if self.opt["rank"] == 0: + self.module.save_pretrained(save_dir) + + if self.opt["rank"] == 0: + # save the latest checkpoint location to json file + checkpoint_location = { + "checkpoint_tag": str(tag), + "checkpoint_path": os.path.relpath( + self.saveFolder, start=self.opt["datadir"] + ), + } + json.dump( + checkpoint_location, + open( + os.path.join( + self.opt["datadir"], + self.opt["basename"] + "_resume_checkpoint.json", + ), + "w", + encoding="utf-8", + ), + ) + self.log(f"Finished saving checkpoint and model to {save_dir}.") + + def load_model(self, model_path): + # Load the model only, without any training states, using the from_pretrained API + self.module = self.module.from_pretrained(model_path) + self.module.to(self.opt["device"]) + + def load_checkpoint(self): + """ + Load complete training states, including model weights, optimizer, lr_scheduler, + fp16 loss scaler, random state, batch generator, and updates count + """ + try: + # load the checkpoint location from json file + checkpoint_location = json.load( + open( + os.path.join( + self.opt["datadir"], + self.opt["basename"] + "_resume_checkpoint.json", + ), + encoding="utf-8", + ) + ) + checkpoint_path = os.path.join( + self.opt["datadir"], + checkpoint_location["checkpoint_path"], + checkpoint_location["checkpoint_tag"], + ) + tag = checkpoint_location["checkpoint_tag"] + if not os.path.isdir(checkpoint_path): + if self.opt["rank"] == 0: + self.log( + "Checkpoint path {} not exist. 
Continue without loading checkpoint".format( + checkpoint_path + ) + ) + return + except: + if self.opt["rank"] == 0: + self.log( + f"Cannot find checkpoint path from {self.opt['basename']+'_resume_checkpoint.json'}.\n" + f"Make sure {os.path.join(self.opt['datadir'], self.opt['basename']+'_resume_checkpoint.json')} exists.\n" + f"Continue without loading checkpoint" + ) + return + # save a copy of the resumed checkpoint location in the save folder of current run + if self.opt["rank"] == 0: + json.dump( + checkpoint_location, + open( + os.path.join(self.saveFolder, "resumed_checkpoint.json"), + "w", + encoding="utf-8", + ), + ) + + self.log(f"Loading checkpoint from {checkpoint_path}...") + load_path = os.path.join(checkpoint_path, "training_states.pt") + state = torch.load(load_path, map_location=self.opt["device"]) + self.network.load_state_dict(state["network"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.lr_scheduler.load_state_dict(state["lr_scheduler"]) + if self.opt["fp16"]: + from apex import amp + + amp.load_state_dict(state["amp"]) + self.optim_steps = state["optim_steps"] + self.updates = state["updates"] + self.start_epoch_idx = state["start_epoch_idx"] + self.start_batch_idx = state["start_batch_idx"] + assert self.updates_per_epoch == state["updates_per_epoch"] + assert self.start_batch_idx < self.updates_per_epoch + + random_state_path = os.path.join( + checkpoint_path, "random_state_rank_{:04d}".format(self.opt["rank"]) + ) + random_state = torch.load(random_state_path, map_location="cpu") + random.setstate(random_state["random"]) + np.random.set_state(random_state["numpy_random"]) + torch.set_rng_state(random_state["torch_random"]) + if self.use_cuda: + torch.cuda.set_rng_state( + random_state["torch_cuda_random"], device=self.opt["device"] + ) + + if "RESET_DATA_LOADER" not in self.opt and isinstance( + self.train_batch_generator, iterators.CheckpointableIterator + ): + batch_generator_file_path = os.path.join( + checkpoint_path, + "batch_generator_checkpoint_rank_{:04d}".format(self.opt["rank"]), + ) + batch_generator_state = torch.load( + batch_generator_file_path, map_location="cpu" + ) + self.train_batch_generator.setstate(batch_generator_state) + else: + self.log( + "No need to resume batch generator or batch generator is not checkpointable. Didn't load from checkpoint." + ) + self.log(f"Finished loading checkpoint from {checkpoint_path}.") diff --git a/model/third_party/HMNet/Models/Trainers/Tasks.py b/model/third_party/HMNet/Models/Trainers/Tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..7463abfd9d547af935838c85d0b711998d620902 --- /dev/null +++ b/model/third_party/HMNet/Models/Trainers/Tasks.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +class Task: + """ + This class is the ensemble of two classes: BatchGen and Eval. + The `setup_task` function defines tasks w.r.t the three components based + on the `task_name`. 
+
+    """
+
+    def __init__(self, batch_gen, evaluator):
+        self.batch_gen = batch_gen
+        self.evaluator = evaluator
+
+    @classmethod
+    def setup_task(cls, task_name, opt, save_dir):
+
+        if task_name == "HMNet":
+            from model.third_party.HMNet.Utils.HMNet.InfinibatchLoader import (
+                HMNetBatchGen,
+            )
+
+            batch_gen = HMNetBatchGen
+            from model.third_party.HMNet.Evaluation.ROUGEEval import ROUGEEval
+
+            evaluator = ROUGEEval(opt["datadir"], save_dir, opt)
+        else:
+            print("ERROR: Task {} not defined".format(task_name))
+            assert False
+
+        return cls(batch_gen, evaluator)
diff --git a/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/LICENSE b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7
--- /dev/null
+++ b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/LICENSE
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner.
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred b y, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/file_utils.py b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db18a53c7fc6f77e7ab106701132d0321f8cee6b --- /dev/null +++ b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/file_utils.py @@ -0,0 +1,534 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. 
+""" + +import fnmatch +import json +import logging +import os +import shutil +import sys +import tarfile +import tempfile +from contextlib import contextmanager +from functools import partial, wraps +from hashlib import sha256 +from typing import Optional +from urllib.parse import urlparse +from zipfile import ZipFile, is_zipfile + +import boto3 +import requests +from botocore.config import Config +from botocore.exceptions import ClientError +from filelock import FileLock +from tqdm.auto import tqdm + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +try: + USE_TF = os.environ.get("USE_TF", "AUTO").upper() + USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): + import torch + + _torch_available = True # pylint: disable=invalid-name + logger.info("PyTorch version {} available.".format(torch.__version__)) + else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False +except ImportError: + _torch_available = False # pylint: disable=invalid-name + +try: + USE_TF = os.environ.get("USE_TF", "AUTO").upper() + USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + + if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): + import tensorflow as tf + + assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 + _tf_available = True # pylint: disable=invalid-name + logger.info("TensorFlow version {} available.".format(tf.__version__)) + else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False +except (ImportError, AssertionError): + _tf_available = False # pylint: disable=invalid-name + +try: + from torch.hub import _get_torch_home + + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv( + "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch") + ) + ) +default_cache_path = os.path.join(torch_cache_home, "transformers") + +try: + from pathlib import Path + + PYTORCH_PRETRAINED_BERT_CACHE = Path( + os.getenv( + "PYTORCH_TRANSFORMERS_CACHE", + os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path), + ) + ) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( + "PYTORCH_TRANSFORMERS_CACHE", + os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path), + ) + +PYTORCH_TRANSFORMERS_CACHE = ( + PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility +) +TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility + +WEIGHTS_NAME = "pytorch_model.bin" +TF2_WEIGHTS_NAME = "tf_model.h5" +TF_WEIGHTS_NAME = "model.ckpt" +CONFIG_NAME = "config.json" +MODEL_CARD_NAME = "modelcard.json" + + +MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]] +DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] +DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] + +S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" +CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net" + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available + + +def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +def add_start_docstrings_to_callable(*docstr): + def docstring_decorator(fn): + class_name = 
":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0]) + intro = " The {} forward method, overrides the :func:`__call__` special method.".format( + class_name + ) + note = r""" + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + pre and post processing steps while the latter silently ignores them. + """ + fn.__doc__ = ( + intro + + note + + "".join(docstr) + + (fn.__doc__ if fn.__doc__ is not None else "") + ) + return fn + + return docstring_decorator + + +def add_end_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = fn.__doc__ + "".join(docstr) + return fn + + return docstring_decorator + + +def is_remote_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https", "s3") + + +def hf_bucket_url(identifier, postfix=None, cdn=False) -> str: + endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX + if postfix is None: + return "/".join((endpoint, identifier)) + else: + return "/".join((endpoint, identifier, postfix)) + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name + so that TF 2.0 can identify it as a HDF5 file + (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) + """ + url_bytes = url.encode("utf-8") + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = sha256(etag_bytes) + filename += "." + etag_hash.hexdigest() + + if url.endswith(".h5"): + filename += ".h5" + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = TRANSFORMERS_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + ".json" + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata["url"] + etag = metadata["etag"] + + return url, etag + + +def cached_path( + url_or_filename, + cache_dir=None, + force_download=False, + proxies=None, + resume_download=False, + user_agent=None, + extract_compressed_file=False, + force_extract=False, + local_files_only=False, +) -> Optional[str]: + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + Args: + cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). + force_download: if True, re-dowload the file even if it's already cached in the cache dir. + resume_download: if True, resume the download if incompletly recieved file is found. 
+ user_agent: Optional string or dict that will be appended to the user-agent on remote requests. + extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed + file in a folder along the archive. + force_extract: if True when extract_compressed_file is True and the archive was already extracted, + re-extract the archive and overide the folder where it was extracted. + + Return: + None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). + Local path (string) otherwise + """ + if cache_dir is None: + cache_dir = TRANSFORMERS_CACHE + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if is_remote_url(url_or_filename): + # URL, so get it from the cache (downloading if necessary) + output_path = get_from_cache( + url_or_filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + user_agent=user_agent, + local_files_only=local_files_only, + ) + elif os.path.exists(url_or_filename): + # File, and it exists. + output_path = url_or_filename + elif urlparse(url_or_filename).scheme == "": + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) + + if extract_compressed_file: + if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): + return output_path + + # Path where we extract compressed archives + # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" + output_dir, output_file = os.path.split(output_path) + output_extract_dir_name = output_file.replace(".", "-") + "-extracted" + output_path_extracted = os.path.join(output_dir, output_extract_dir_name) + + if ( + os.path.isdir(output_path_extracted) + and os.listdir(output_path_extracted) + and not force_extract + ): + return output_path_extracted + + # Prevent parallel extractions + lock_path = output_path + ".lock" + with FileLock(lock_path): + shutil.rmtree(output_path_extracted, ignore_errors=True) + os.makedirs(output_path_extracted) + if is_zipfile(output_path): + with ZipFile(output_path, "r") as zip_file: + zip_file.extractall(output_path_extracted) + zip_file.close() + elif tarfile.is_tarfile(output_path): + tar_file = tarfile.open(output_path) + tar_file.extractall(output_path_extracted) + tar_file.close() + else: + raise EnvironmentError( + "Archive format of {} could not be identified".format(output_path) + ) + + return output_path_extracted + + return output_path + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. 
+ """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url, proxies=None): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file, proxies=None): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): + ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) + if is_torch_available(): + ua += "; torch/{}".format(torch.__version__) + if is_tf_available(): + ua += "; tensorflow/{}".format(tf.__version__) + if isinstance(user_agent, dict): + ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + headers = {"user-agent": ua} + if resume_size > 0: + headers["Range"] = "bytes=%d-" % (resume_size,) + response = requests.get(url, stream=True, proxies=proxies, headers=headers) + if response.status_code == 416: # Range not satisfiable + return + content_length = response.headers.get("Content-Length") + total = resume_size + int(content_length) if content_length is not None else None + progress = tqdm( + unit="B", + unit_scale=True, + total=total, + initial=resume_size, + desc="Downloading", + disable=bool(logger.getEffectiveLevel() == logging.NOTSET), + ) + for chunk in response.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache( + url, + cache_dir=None, + force_download=False, + proxies=None, + etag_timeout=10, + resume_download=False, + user_agent=None, + local_files_only=False, +) -> Optional[str]: + """ + Given a URL, look for the corresponding file in the local cache. + If it's not there, download it. Then return the path to the cached file. + + Return: + None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). + Local path (string) otherwise + """ + if cache_dir is None: + cache_dir = TRANSFORMERS_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + os.makedirs(cache_dir, exist_ok=True) + + etag = None + if not local_files_only: + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url, proxies=proxies) + else: + try: + response = requests.head( + url, allow_redirects=True, proxies=proxies, timeout=etag_timeout + ) + if response.status_code == 200: + etag = response.headers.get("ETag") + except (EnvironmentError, requests.exceptions.Timeout): + # etag is already None + pass + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. 
+    # try to get the last downloaded one
+    if etag is None:
+        if os.path.exists(cache_path):
+            return cache_path
+        else:
+            matching_files = [
+                file
+                for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
+                if not file.endswith(".json") and not file.endswith(".lock")
+            ]
+            if len(matching_files) > 0:
+                return os.path.join(cache_dir, matching_files[-1])
+            else:
+                # If files cannot be found and local_files_only=True,
+                # the models might've been found if local_files_only=False
+                # Notify the user about that
+                if local_files_only:
+                    raise ValueError(
+                        "Cannot find the requested files in the cached path and outgoing traffic has been"
+                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
+                        " to False."
+                    )
+                return None
+
+    # From now on, etag is not None.
+    if os.path.exists(cache_path) and not force_download:
+        return cache_path
+
+    # Prevent parallel downloads of the same file with a lock.
+    lock_path = cache_path + ".lock"
+    with FileLock(lock_path):
+
+        if resume_download:
+            incomplete_path = cache_path + ".incomplete"
+
+            @contextmanager
+            def _resumable_file_manager():
+                with open(incomplete_path, "a+b") as f:
+                    yield f
+
+            temp_file_manager = _resumable_file_manager
+            if os.path.exists(incomplete_path):
+                resume_size = os.stat(incomplete_path).st_size
+            else:
+                resume_size = 0
+        else:
+            temp_file_manager = partial(
+                tempfile.NamedTemporaryFile, dir=cache_dir, delete=False
+            )
+            resume_size = 0
+
+        # Download to temporary file, then copy to cache dir once finished.
+        # Otherwise you get corrupt cache entries if the download gets interrupted.
+        with temp_file_manager() as temp_file:
+            logger.info(
+                "%s not found in cache or force_download set to True, downloading to %s",
+                url,
+                temp_file.name,
+            )
+
+            # GET file object
+            if url.startswith("s3://"):
+                if resume_download:
+                    logger.warn(
+                        'Warning: resumable downloads are not implemented for "s3://" urls'
+                    )
+                s3_get(url, temp_file, proxies=proxies)
+            else:
+                http_get(
+                    url,
+                    temp_file,
+                    proxies=proxies,
+                    resume_size=resume_size,
+                    user_agent=user_agent,
+                )
+
+        logger.info("storing %s in cache at %s", url, cache_path)
+        os.rename(temp_file.name, cache_path)
+
+        logger.info("creating metadata file for %s", cache_path)
+        meta = {"url": url, "etag": etag}
+        meta_path = cache_path + ".json"
+        with open(meta_path, "w") as meta_file:
+            json.dump(meta, meta_file)
+
+    return cache_path
diff --git a/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/modeling_encoder_decoder.py b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/modeling_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d8a9cdead5c33b9c4507d1bf38277ef05e3f91
--- /dev/null
+++ b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/modeling_encoder_decoder.py
@@ -0,0 +1,1410 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Classes to support Encoder-Decoder architectures """ + + +import logging +import os + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F + + +logger = logging.getLogger(__name__) + + +class PreTrainedEncoderDecoder(nn.Module): + r""" + :class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be + instantiated as a transformer architecture with one of the base model + classes of the library as encoder and (optionally) another one as + decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` + class method. + """ + + def __init__(self, encoder, decoder): + super().__init__() + self.encoder = encoder + self.decoder = decoder + # manually set the self.config + self.config = decoder.config + self.config.is_encoder_decoder = True + + @classmethod + def from_pretrained( + cls, + encoder_pretrained_model_name_or_path=None, + decoder_pretrained_model_name_or_path=None, + *model_args, + **kwargs, + ): + r"""Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you need to first set it back in training mode with `model.train()` + + Params: + encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + decoder_pretrained_model_name_or_path: information necessary to initiate the decoder. Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/decoder``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. 
Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments. + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + You can specify kwargs sep cific for the encoder and decoder by prefixing the key with `encoder_` and `decoder_` respectively. (e.g. ``decoder_output_attention=True``). The remaining kwargs will be passed to both encoders and decoders. + + Examples:: + + # For example purposes. Not runnable. + model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert + """ + + # keyword arguments come in 3 flavors: encoder-specific (prefixed by + # `encoder_`), decoder-specific (prefixed by `decoder_`) and those + # that apply to the model as a whole. + # We let the specific kwargs override the common ones in case of conflict. 
+ kwargs_common = { + argument: value + for argument, value in kwargs.items() + if not argument.startswith("encoder_") + and not argument.startswith("decoder_") + } + kwargs_decoder = kwargs_common.copy() + kwargs_encoder = kwargs_common.copy() + kwargs_encoder.update( + { + argument[len("encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("encoder_") + } + ) + kwargs_decoder.update( + { + argument[len("decoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("decoder_") + } + ) + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + encoder = AutoModel.from_pretrained( + encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder + ) + encoder.config.is_decoder = False + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + decoder = AutoModelWithLMHead.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder + ) + decoder.config.is_decoder = True + + model = cls(encoder, decoder) + + return model + + def save_pretrained(self, save_directory): + """Save a Seq2Seq model and its configuration file in a format such + that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained` + + We save the encoder' and decoder's parameters in two separate directories. + """ + + # If the root output directory does not exist, create it + if not os.path.exists(save_directory): + os.mkdir(save_directory) + + # Check whether the output directory is empty or not + sub_directories = [ + directory + for directory in os.listdir(save_directory) + if os.path.isdir(os.path.join(save_directory, directory)) + ] + + if len(sub_directories) > 0: + if "encoder" in sub_directories and "decoder" in sub_directories: + print( + "WARNING: there is an older version of encoder-decoder saved in" + + " the output directory. The default behaviour is to overwrite them." + ) + + # Empty the output directory + for directory_to_remove in sub_directories: + # Remove all files into the subdirectory + files_to_remove = os.listdir( + os.path.join(save_directory, directory_to_remove) + ) + for file_to_remove in files_to_remove: + os.remove( + os.path.join( + save_directory, directory_to_remove, file_to_remove + ) + ) + # Remove the subdirectory itself + os.rmdir(os.path.join(save_directory, directory_to_remove)) + + assert len(os.listdir(save_directory)) == 0 # sanity check + + # Create the "encoder" directory inside the output directory and save the encoder into it + if not os.path.exists(os.path.join(save_directory, "encoder")): + os.mkdir(os.path.join(save_directory, "encoder")) + self.encoder.save_pretrained(os.path.join(save_directory, "encoder")) + + # Create the "encoder" directory inside the output directory and save the decoder into it + if not os.path.exists(os.path.join(save_directory, "decoder")): + os.mkdir(os.path.join(save_directory, "decoder")) + self.decoder.save_pretrained(os.path.join(save_directory, "decoder")) + + @staticmethod + def prepare_model_kwargs(**kwargs): + """Prepare the encoder and decoder's keyword arguments. + Keyword arguments come in 3 flavors: + - encoder-specific (prefixed by `encoder_`) + - decoder-specific (prefixed by `decoder_`) + - those that apply to the model as whole. + We let the specific kwargs override the common ones in case of + conflict. 
+ """ + kwargs_common = { + argument: value + for argument, value in kwargs.items() + if not argument.startswith("encoder_") + and not argument.startswith("decoder_") + } + decoder_kwargs = kwargs_common.copy() + encoder_kwargs = kwargs_common.copy() + encoder_kwargs.update( + { + argument[len("encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("encoder_") + } + ) + decoder_kwargs.update( + { + argument[len("decoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("decoder_") + } + ) + decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get( + "attention_mask", None + ) + return encoder_kwargs, decoder_kwargs + + def forward(self, encoder_input_ids=None, decoder_input_ids=None, **kwargs): + """The forward pass on a seq2eq depends what we are performing: + + - During training we perform one forward pass through both the encoder + and decoder; + - During prediction, we perform one forward pass through the encoder, + and then perform several forward passes with the encoder's hidden + state through the decoder to decode a full sequence. + + Therefore, we skip the forward pass on the encoder if an argument named + `encoder_hidden_state` is passed to this function. + + Params: + encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` + Indices of encoder input sequence tokens in the vocabulary. + decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` + Indices of decoder input sequence tokens in the vocabulary. + kwargs: (`optional`) Remaining dictionary of keyword arguments. + """ + kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs) + + # Encode if needed (training, first prediction pass) + encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) + if encoder_hidden_states is None: + encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) + encoder_hidden_states = encoder_outputs[0] + else: + encoder_outputs = () + + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states + decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) + + return decoder_outputs + encoder_outputs + + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + # first step + if type(past) is tuple: + encoder_outputs = past + else: + encoder_outputs = (past,) + + return { + "decoder_input_ids": input_ids, + "encoder_outputs": encoder_outputs, + "encoder_hidden_states": encoder_outputs[0], + "decoder_attention_mask": None, + } + + def prepare_scores_for_generation(self, scores, **kwargs): + return scores + + def _do_output_past(self, outputs): + """During generation, decide whether to pass the `past` variable to the next forward pass.""" + has_output_past = getattr(self.config, "output_past", False) + mem_len = getattr(self.config, "mem_len", 0) + if len(outputs) <= 1: + return False + if mem_len > 0 or has_output_past: + return True + return False + + def enforce_repetition_penalty_( + self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty + ): + """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858).""" + for i in range(batch_size * num_beams): + for previous_token in set(prev_output_tokens[i].tolist()): + # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if lprobs[i, previous_token] < 0: + lprobs[i, previous_token] *= repetition_penalty + else: + 
lprobs[i, previous_token] /= repetition_penalty + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + @torch.no_grad() + def generate( + self, + input_ids=None, + max_length=None, + min_length=None, + do_sample=None, + early_stopping=None, + num_beams=None, + temperature=None, + top_k=None, + top_p=None, + repetition_penalty=None, + bad_words_ids=None, + bos_token_id=None, + pad_token_id=None, + eos_token_id=None, + length_penalty=None, + no_repeat_ngram_size=None, + num_return_sequences=None, + attention_mask=None, + decoder_start_token_id=None, + ): + r"""Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. + + Adapted in part from `Facebook's XLM beam search code`_. + + .. _`Facebook's XLM beam search code`: + https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529 + + + Parameters: + + input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)` + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape `(1,)`. + + max_length: (`optional`) int + The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20. + + min_length: (`optional`) int + The min length of the sequence to be generated. Between 0 and infinity. Default to 0. + + do_sample: (`optional`) bool + If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. + + early_stopping: (`optional`) bool + if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. + + num_beams: (`optional`) int + Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. + + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + + top_k: (`optional`) int + The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + + top_p: (`optional`) float + The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. + + repetition_penalty: (`optional`) float + The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. + + pad_token_id: (`optional`) int + Padding token. Default to specicic model pad_token_id or None if it does not exist. + + bos_token_id: (`optional`) int + BOS token. Defaults to `bos_token_id` as defined in the models config. + + eos_token_id: (`optional`) int + EOS token. Defaults to `eos_token_id` as defined in the models config. + + length_penalty: (`optional`) float + Exponential penalty to the length. Default to 1. + + no_repeat_ngram_size: (`optional`) int + If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once. + bad_words_ids: (`optional`) list of lists of int + `bad_words_ids` contains tokens that are not allowed to be generated. 
In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. + + num_return_sequences: (`optional`) int + The number of independently computed returned sequences for each element in the batch. Default to 1. + + attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids` + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + Defaults to `None`. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + decoder_start_token_id=None: (`optional`) int + If an encoder-decoder model starts decoding with a different token than BOS. + Defaults to `None` and is changed to `BOS` later. + + Return: + + output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)` + sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id` + + Examples:: + + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + outputs = model.generate(max_length=40) # do greedy decoding + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. + input_context = 'The dog' + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. + input_context = 'The dog' + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling + for i in range(3): # 3 output sequences were generated + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. + input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. 
+ input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl + bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated + """ + + # We cannot generate if the model does not have a LM head + if self.get_output_embeddings() is None: + raise AttributeError( + "You tried to generate sequences with a model that does not have a LM Head." + "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" + ) + + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + do_sample = do_sample if do_sample is not None else self.config.do_sample + early_stopping = ( + early_stopping if early_stopping is not None else self.config.early_stopping + ) + num_beams = num_beams if num_beams is not None else self.config.num_beams + temperature = ( + temperature if temperature is not None else self.config.temperature + ) + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + repetition_penalty = ( + repetition_penalty + if repetition_penalty is not None + else self.config.repetition_penalty + ) + bos_token_id = ( + bos_token_id if bos_token_id is not None else self.config.bos_token_id + ) + pad_token_id = ( + pad_token_id if pad_token_id is not None else self.config.pad_token_id + ) + eos_token_id = ( + eos_token_id if eos_token_id is not None else self.config.eos_token_id + ) + length_penalty = ( + length_penalty if length_penalty is not None else self.config.length_penalty + ) + no_repeat_ngram_size = ( + no_repeat_ngram_size + if no_repeat_ngram_size is not None + else self.config.no_repeat_ngram_size + ) + bad_words_ids = ( + bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + ) + num_return_sequences = ( + num_return_sequences + if num_return_sequences is not None + else self.config.num_return_sequences + ) + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.config.decoder_start_token_id + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] # overriden by the input batch_size + else: + batch_size = 1 + + assert ( + isinstance(max_length, int) and max_length > 0 + ), "`max_length` should be a strictly positive integer." + assert ( + isinstance(min_length, int) and min_length >= 0 + ), "`min_length` should be a positive integer." + assert isinstance(do_sample, bool), "`do_sample` should be a boolean." + assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." + assert ( + isinstance(num_beams, int) and num_beams > 0 + ), "`num_beams` should be a strictly positive integer." + assert temperature > 0, "`temperature` should be strictly positive." + assert ( + isinstance(top_k, int) and top_k >= 0 + ), "`top_k` should be a positive integer." + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." + assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." 
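All of these generation knobs follow the same resolution rule: an explicit keyword argument wins, otherwise the value stored on the model config is used, and the resolved value is then sanity-checked. A small, self-contained sketch of that pattern; the SimpleNamespace config and the three argument names below are illustrative, not part of this module:

from types import SimpleNamespace

# Hypothetical config carrying the same kind of generation defaults used above.
config = SimpleNamespace(max_length=20, num_beams=1, temperature=1.0)

def resolve_generation_args(config, max_length=None, num_beams=None, temperature=None):
    """Mirror of the `x if x is not None else config.x` fallback used in generate()."""
    max_length = max_length if max_length is not None else config.max_length
    num_beams = num_beams if num_beams is not None else config.num_beams
    temperature = temperature if temperature is not None else config.temperature
    # The same style of sanity checks as in generate():
    assert isinstance(max_length, int) and max_length > 0
    assert isinstance(num_beams, int) and num_beams > 0
    assert temperature > 0
    return max_length, num_beams, temperature

print(resolve_generation_args(config, num_beams=4))  # (20, 4, 1.0)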
+ assert input_ids is not None or (
+ isinstance(bos_token_id, int) and bos_token_id >= 0
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
+ assert pad_token_id is None or (
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
+ ), "`pad_token_id` should be a positive integer."
+ assert (eos_token_id is None) or (
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
+ ), "`eos_token_id` should be a positive integer."
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
+ assert (
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
+ ), "`no_repeat_ngram_size` should be a positive integer."
+ assert (
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
+ ), "`num_return_sequences` should be a strictly positive integer."
+ assert (
+ bad_words_ids is None
+ or isinstance(bad_words_ids, list)
+ and isinstance(bad_words_ids[0], list)
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
+
+ if input_ids is None:
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
+ "you should either supply a context to complete as `input_ids` input "
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
+ )
+ input_ids = torch.full(
+ (batch_size, 1),
+ bos_token_id,
+ dtype=torch.long,
+ device=next(self.parameters()).device,
+ )
+ else:
+ assert (
+ input_ids.dim() == 2
+ ), "Input prompt should be of shape (batch_size, sequence length)."
+
+ # not allow to duplicate outputs when greedy decoding
+ if do_sample is False:
+ if num_beams == 1:
+ # no_beam_search greedy generation conditions
+ assert (
+ num_return_sequences == 1
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+
+ else:
+ # beam_search greedy generation conditions
+ assert (
+ num_beams >= num_return_sequences
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
+ # create attention mask if necessary
+ # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
+ if (
+ (attention_mask is None)
+ and (pad_token_id is not None)
+ and (pad_token_id in input_ids)
+ ):
+ attention_mask = input_ids.ne(pad_token_id).long()
+ elif attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ # set pad_token_id to eos_token_id if not set.
Important that this is done after + # attention_mask is created + if pad_token_id is None and eos_token_id is not None: + logger.warning( + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format( + eos_token_id + ) + ) + pad_token_id = eos_token_id + + # current position and vocab size + vocab_size = self.config.vocab_size + + # set effective batch size and effective batch multiplier according to do_sample + if do_sample: + effective_batch_size = batch_size * num_return_sequences + effective_batch_mult = num_return_sequences + else: + effective_batch_size = batch_size + effective_batch_mult = 1 + + if self.config.is_encoder_decoder: + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + + assert ( + decoder_start_token_id is not None + ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" + assert hasattr( + self, "get_encoder" + ), "{} should have a 'get_encoder' function defined".format(self) + assert callable(self.get_encoder), "{} should be a method".format( + self.get_encoder + ) + + # get encoder and store encoder outputs + encoder = self.get_encoder() + + encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + + # Expand input ids if num_beams > 1 or num_return_sequences > 1 + if num_return_sequences > 1 or num_beams > 1: + input_ids_len = input_ids.shape[-1] + input_ids = input_ids.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, input_ids_len + ) + attention_mask = attention_mask.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, input_ids_len + ) + + input_ids = input_ids.contiguous().view( + effective_batch_size * num_beams, input_ids_len + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + attention_mask = attention_mask.contiguous().view( + effective_batch_size * num_beams, input_ids_len + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + + if self.config.is_encoder_decoder: + # create empty decoder_input_ids + input_ids = torch.full( + (effective_batch_size * num_beams, 1), + decoder_start_token_id, + dtype=torch.long, + device=next(self.parameters()).device, + ) + cur_len = 1 + + assert ( + batch_size == encoder_outputs[0].shape[0] + ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " + + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) + expanded_batch_idxs = ( + torch.arange(batch_size) + .view(-1, 1) + .repeat(1, num_beams * effective_batch_mult) + .view(-1) + .to(input_ids.device) + ) + # expand encoder_outputs + encoder_outputs = ( + encoder_outputs[0].index_select(0, expanded_batch_idxs), + *encoder_outputs[1:], + ) + + else: + encoder_outputs = None + cur_len = input_ids .shape[-1] + + if num_beams > 1: + output = self._generate_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + decoder_start_token_id=decoder_start_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + 
encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + ) + else: + output = self._generate_no_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + decoder_start_token_id=decoder_start_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + ) + + return output + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + min_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + bos_token_id, + pad_token_id, + eos_token_id, + decoder_start_token_id, + batch_size, + encoder_outputs, + attention_mask, + ): + """Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + # length of generated sentences / unfinished sentences + unfinished_sents = input_ids.new(batch_size).fill_(1) + sent_lengths = input_ids.new(batch_size).fill_(max_length) + + past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask + ) + + outputs = self(**model_inputs) + next_token_logits = outputs[0][:, -1, :] + + # if model has past, then set the past variable to speed up decoding + if self._do_output_past(outputs): + past = outputs[1] + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + self.enforce_repetition_penalty_( + next_token_logits, batch_size, 1, input_ids, repetition_penalty + ) + + if no_repeat_ngram_size > 0: + # calculate a list of banned tokens to prevent repetitively generating the same ngrams + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 + banned_tokens = calc_banned_ngram_tokens( + input_ids, batch_size, no_repe at_ngram_size, cur_len + ) + for batch_idx in range(batch_size): + next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float( + "inf" + ) + + if bad_words_ids is not None: + # calculate a list of banned tokens according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + + for batch_idx in range(batch_size): + next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float( + "inf" + ) + + # set eos token prob to zero if min_length is not reached + if eos_token_id is not None and cur_len < min_length: + next_token_logits[:, eos_token_id] = -float("inf") + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + # Top-p/top-k filtering + next_token_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + # Sample + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1) + + # update generations and finished sentences + if eos_token_id is not None: + # pad finished sentences if eos_token_id exist + tokens_to_add = 
next_token * unfinished_sents + (pad_token_id) * ( + 1 - unfinished_sents + ) + else: + tokens_to_add = next_token + + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + + if eos_token_id is not None: + eos_in_sents = tokens_to_add == eos_token_id + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul( + eos_in_sents.long() + ).bool() + sent_lengths.masked_fill_( + is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1 + ) + # unfinished_sents is set to zero if eos in sentence + unfinished_sents.mul_((~eos_in_sents).long()) + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sents.max() == 0: + break + + # extend attention_mask for new generated input if only decoder + if self.config.is_encoder_decoder is False: + attention_mask = torch.cat( + [ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], 1)), + ], + dim=-1, + ) + + cur_len = cur_len + 1 + + # if there are different sentences lengths in the batch, some batches have to be padded + if sent_lengths.min().item() != sent_lengths.max().item(): + assert ( + pad_token_id is not None + ), "`Pad_token_id` has to be defined if batches have different lengths" + # finished sents are filled with pad_token + decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_( + pad_token_id + ) + else: + decoded = input_ids + + for hypo_idx, hypo in enumerate(input_ids): + decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]] + + return decoded + + def _generate_beam_search( + self, + input_ids, + cur_len, + max_length , + min_length, + do_sample, + early_stopping, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + bos_token_id, + pad_token_id, + eos_token_id, + decoder_start_token_id, + batch_size, + num_return_sequences, + length_penalty, + num_beams, + vocab_size, + encoder_outputs, + attention_mask, + ): + """Generate sequences for each example with beam search.""" + + # generated hypotheses + generated_hyps = [ + BeamHypotheses( + num_beams, max_length, length_penalty, early_stopping=early_stopping + ) + for _ in range(batch_size) + ] + + # scores for each sentence in the beam + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device + ) + + # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times + if do_sample is False: + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) + + # cache compute states + past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models + + # done sentences + done = [False for _ in range(batch_size)] + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask + ) + outputs = self( + **model_inputs + ) # (batch_size * num_beams, cur_len, vocab_size) + next_token_logits = outputs[0][ + :, -1, : + ] # (batch_size * num_beams, vocab_size) + + # if model has past, then set the past variable to speed up decoding + if self._do_output_past(outputs): + past = outputs[1] + + # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + self.enforce_repetition_penalty_( + next_token_logits, + batch_size, + num_beams, + input_ids, + 
repetition_penalty, + ) + + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + + scores = F.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + if self.config.is_encoder_decoder and do_sample is False: + # TODO (PVP) still a bit hacky here - there might be a better solutino + scores = self.prepare_scores_for_generation( + scores, cur_len=cur_len, max_length=max_length + ) + + # set eos token prob to zero if min_length is not reached + if eos_token_id is not None and cur_len < min_length: + scores[:, eos_token_id] = -float("inf") + + if no_repeat_ngram_size > 0: + # calculate a list of banned tokens to prevent repetitively generating the same ngrams + num_batch_hypotheses = batch_size * num_beams + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 + banned_batch_tokens = calc_banned_ngram_tokens( + input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len + ) + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + if bad_words_ids is not None: + # calculate a list of banned tokens according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + + for i, banned_tokens in enumerate(banned_tokens): + scores[i, banned_tokens] = -float("inf") + + assert scores.shape == ( + batch_size * num_beams, + vocab_size, + ), "Shapes of scores: {} != {}".format( + scores.shape, (batch_size * num_beams, vocab_size) + ) + + if do_sample: + _scores = scores + beam_scores[:, None].expand_as( + scores + ) # (batch_size * num_beams, vocab_size) + # Top-p/top-k filtering + _scores = top_k_top_p_filtering( + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + ) # (batch_size * num_beams, vocab_size) + # re-organize to group the beam together to sample from all beam_idxs + _scores = _scores.contiguous().view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) + probs = F.softmax(_scores, dim=-1) + next_tokens = torch.multinomial( + probs, num_samples=2 * num_beams + ) # (batch_size, num_beams * 2) + # Compute next scores + next_scores = torch.gather( + _scores, -1, next_tokens + ) # (batch_size, num_beams * 2) + # sort the sampled vector to make sure that the first num_beams samples are the best + next_scores, next_scores_indices = torch.sort( + next_scores, descending=True, dim=1 + ) + next_tokens = torch.gather( + next_tokens, -1, next_scores_indices + ) # (batch_size, num_beams * 2) + + else: + next_scores = scores + beam_scores[:, None].expand_as( + scores + ) # (batch_size * num_beams, vocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + next_scores = next_scores.view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + + next_scores, next_tokens = torch.topk( + next_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + assert ( + next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) + ) + + # next batch beam content + next_batch_beam = [] + + # for each sentence + for batch_idx in range(batch_size): + + # if we are done with this sentence + if done[batch_idx]: + assert ( + len(generated_hyps[batch_idx]) >= num_beams + ), "Batch can only be done if at least {} beams have been generated".format( + num_beams + ) + assert ( + eos_token_id 
is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + next_batch_beam.extend( + [(0, pad_token_id, 0)] * num_beams + ) # pad the batch + continue + + # next sentence beam content + next_sent_beam = [] + + # next tokens for this sentence + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx]) + ): + # get beam and token IDs + beam_id = beam_token_id // vocab_size + token_id = beam_token_id % vocab_size + + effective_beam_id = batch_idx * num_beams + beam_id + # add to generated hypotheses if end of sentence or last iteration + if (eos_token_id is not None) and (token_id.item() == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = ( + beam_token_rank >= num_beams + ) + if is_beam_token_worse_than_top_num_beams: + continue + generated_hyps[batch_idx].add( + input_ids[effective_beam_id].clone(), + beam_token_score.item(), + ) + else: + # add next predicted token if it is not eos_token + next_sent_beam.append( + (beam_token_score, token_id, effective_beam_id) + ) + + # the beam for next step is full + if len(next_sent_beam) == num_beams: + break + + # Check if were done so that we can save a pad step if all(done) + done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( + next_scores[batch_idx].max().item(), cur_len=cur_len + ) + + # update next beam content + assert len(next_sent_beam) == num_beams, "Beam should always be full" + next_batch_beam.extend(next_sent_beam) + assert len(next_batch_beam) == num_beams * (batch_idx + 1) + + # stop when we are done with each sentence + if all(done): + break + + # sanity check / prepare next batch + assert len(next_batch_beam) == batch_size * num_beams + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) + beam_idx = input_ids.new([x[2] for x in next_batch_beam]) + + # re-order batch + input_ids = input_ids[beam_idx, :] + input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) + # re-order internal states + if past is not None: + past = self._reorder_cache(past, beam_idx) + + # extend attention_mask for new generated input if only decoder + if self.config.is_encoder_decoder is False: + attention_mask = torch.cat( + [ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], 1)), + ], + dim=-1, + ) + + # update current length + cur_len = cur_len + 1 + + # finalize all open beam hypotheses and end to generated hypotheses + for batch_idx in range(batch_size): + if done[batch_idx]: + continue + + # test that beam scores match previously calculated scores if not eos and batch_idx not done + if eos_token_id is not None and all( + (token_id % vocab_size).item() is not eos_token_id + for token_id in next_tokens[batch_idx] + ): + assert torch.all( + next_scores[batch_idx, :num_beams] + == beam_scores.view(batch_size, num_beams)[batch_idx] + ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( + next_scores[:, :num_beams][batch_idx], + beam_scores.view(batch_size, num_beams)[batch_idx], + ) + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(num_beams): + effective_beam_id = batch_idx * num_beams + beam_id + final_score = beam_scores[effective_beam_id].i tem() + final_tokens = input_ids[effective_beam_id] + 
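The candidate bookkeeping above relies on flat-index arithmetic: once the per-beam scores are reshaped to (batch_size, num_beams * vocab_size), each selected index is split back into a beam id and a token id with floor division and modulo. A toy, standalone check of that arithmetic; the sizes and scores are made up and independent of the model code:

import torch

# Toy setup: 2 beams over a 5-token vocabulary for a single batch element.
num_beams, vocab_size = 2, 5
scores = torch.tensor([[0.10, 0.30, 0.20, 0.05, 0.35],   # beam 0
                       [0.40, 0.10, 0.12, 0.18, 0.08]])  # beam 1

# Same regrouping as in _generate_beam_search: flatten the beams so that topk
# ranks candidates across all beams of the batch element at once.
flat = scores.view(1, num_beams * vocab_size)
top_scores, top_ids = torch.topk(flat, 2 * num_beams, dim=1)

beam_ids = top_ids // vocab_size   # which beam each candidate came from
token_ids = top_ids % vocab_size   # which vocabulary token it proposes
print(beam_ids.tolist(), token_ids.tolist())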
generated_hyps[batch_idx].add(final_tokens, final_score) + + # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch + output_batch_size = ( + batch_size if do_sample else batch_size * num_return_sequences + ) + output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences + + # select the best hypotheses + sent_lengths = input_ids.new(output_batch_size) + best = [] + + # retrieve best hypotheses + for i, hypotheses in enumerate(generated_hyps): + sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) + for j in range(output_num_return_sequences_per_batch): + effective_batch_idx = output_num_return_sequences_per_batch * i + j + best_hyp = sorted_hyps.pop()[1] + sent_lengths[effective_batch_idx] = len(best_hyp) + best.append(best_hyp) + + # shorter batches are filled with pad_token + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`Pad_token_id` has to be defined" + sent_max_len = min(sent_lengths.max().item() + 1, max_length) + decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) + + # fill with hypothesis and eos_token_id if necessary + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < max_length: + decoded[i, sent_lengths[i]] = eos_token_id + else: + # none of the hypotheses have an eos_token + assert (len(hypo) == max_length for hypo in best) + decoded = ( + torch.stack(best).type(torch.long).to(next(self.parameters()).device) + ) + + return decoded + + # force one of token_ids to be generated by setting prob of all other tokens to 0. + def _force_token_ids_generation(self, scores, token_ids): + if isinstance(token_ids, int): + token_ids = [token_ids] + all_but_token_ids_mask = torch.tensor( + [x for x in range(self.config.vocab_size) if x not in token_ids], + dtype=torch.long, + device=next(self.parameters()).device, + ) + assert ( + len(scores.shape) == 2 + ), "scores should be of rank 2 with shape: [batch_size, vocab_size]" + scores[:, all_but_token_ids_mask] = -float("inf") + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = [] + for layer_past in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` and `mems` is at 2nd position + reordered_layer_past = [ + layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx + ] + reordered_layer_past = torch.cat(reordered_layer_past, dim=1) + # check that shape matches + assert reordered_layer_past.shape == layer_past.shape + reordered_past.append(reordered_layer_past) + past = tuple(reordered_past) + return past + + +def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): + # Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < no_repeat_ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get( + prev_ngram_tuple, [] + ) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - 
no_repeat_ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] + return banned_tokens + + +def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids): + banned_tokens = [] + + def _tokens_match(prev_tokens, tokens): + if len(tokens) == 0: + # if bad word tokens is just one token always ban it + return True + if len(tokens) > len(prev_input_ids): + # if bad word tokens are longer then prev input_ids they can't be equal + return False + + if prev_tokens[-len(tokens) :] == tokens: + # if tokens match + return True + else: + return False + + for prev_input_ids_slice in prev_input_ids: + banned_tokens_slice = [] + + for banned_token_seq in bad_words_ids: + assert ( + len(banned_token_seq) > 0 + ), "Banned words token sequences {} cannot have an empty list".format( + bad_words_ids + ) + + if ( + _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) + is False + ): + # if tokens do not match continue + continue + + banned_tokens_slice.append(banned_token_seq[-1]) + + banned_tokens.append(banned_tokens_slice) + + return banned_tokens + + +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +class BeamHypotheses(object): + def __init__(self, num_beams, max_length, length_penalty, early_stopping): + """ + Initialize n-best list of hypotheses. + """ + self. 
max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, hyp, sum_logprobs): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / len(hyp) ** self.length_penalty + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp)) + if len(self) > self.num_beams: + sorted_scores = sorted( + [(s, idx) for idx, (s, _) in enumerate(self.beams)] + ) + del self.beams[sorted_scores[0][1]] + self.worst_score = sorted_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs, cur_len=None): + """ + If there are enough hypotheses and that none of the hypotheses being generated + can become better than the worst one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + if cur_len is None: + cur_len = self.max_length + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self.worst_score >= cur_score + return ret diff --git a/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/tokenization_transfo_xl.py b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/tokenization_transfo_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..930a84de77b2e5ac1f4f25a59cef6dab837f8798 --- /dev/null +++ b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/tokenization_transfo_xl.py @@ -0,0 +1,842 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization classes for Transformer XL model. + Adapted from https://github.com/kimiyoung/transformer-xl. 
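The BeamHypotheses helper that closes the preceding file ranks finished hypotheses by length-normalized log-probability, sum_logprobs / len(hyp) ** length_penalty, and keeps only the num_beams best. A self-contained sketch of that pruning rule, with illustrative token ids and scores:

# Minimal sketch of the BeamHypotheses pruning rule; numbers are illustrative.
def keep_best(hypotheses, num_beams, length_penalty=1.0):
    """hypotheses: list of (sum_logprobs, token_list); return the num_beams best."""
    scored = [
        (sum_logprobs / len(tokens) ** length_penalty, tokens)
        for sum_logprobs, tokens in hypotheses
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:num_beams]

candidates = [(-4.0, [5, 9, 2]), (-3.5, [5, 9]), (-7.0, [5, 9, 2, 2, 2])]
print(keep_best(candidates, num_beams=2))  # the two best length-normalized hypotheses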
+""" + + +import glob +import logging +import os +import pickle +import re +from collections import Counter, OrderedDict +from typing import List, Optional, Tuple, Union + +import numpy as np +from tokenizers import Encoding, Tokenizer +from tokenizers.implementations import BaseTokenizer +from tokenizers.models import WordLevel +from tokenizers.normalizers import Lowercase, Sequence, unicode_normalizer_from_str +from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit +from tokenizers.processors import BertProcessing + +from .file_utils import cached_path, is_torch_available +from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast + + +if is_torch_available(): + import torch + + +logger = logging.getLogger(__name__) + +VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"} +VOCAB_FILES_NAMES_FAST = { + "pretrained_vocab_file": "vocab.json", + "vocab_file": "vocab.json", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "pretrained_vocab_file": { + "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggin gface.co/bert/transfo-xl-wt103-vocab.bin", + } +} + +PRETRAINED_VOCAB_FILES_MAP_FAST = { + "pretrained_vocab_file": { + "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.json", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "transfo-xl-wt103": None, +} + +PRETRAINED_CORPUS_ARCHIVE_MAP = { + "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin", +} +CORPUS_NAME = "corpus.bin" + + +class TransfoXLTokenizer(PreTrainedTokenizer): + """ + Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + special=None, + min_freq=0, + max_size=None, + lower_case=False, + delimiter=None, + vocab_file=None, + pretrained_vocab_file=None, + never_split=None, + unk_token="", + eos_token="", + additional_special_tokens=[""], + **kwargs + ): + super().__init__( + unk_token=unk_token, + eos_token=eos_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.max_len_single_sentence = ( + self.max_len + ) # no default special tokens - you can update this value if you add special tokens + self.max_len_sentences_pair = ( + self.max_len + ) # no default special tokens - you can update this value if you add special tokens + + if never_split is None: + never_split = self.all_special_tokens + if special is None: + special = [] + self.counter = Counter() + self.special = special + self.min_freq = min_freq + self.max_size = max_size + self.lower_case = lower_case + self.delimiter = delimiter + self.vocab_file = vocab_file + self.never_split = never_split + self.punctuation_symbols = '!"#$%&()*+,-./\:;<=>?@[\\]^_`{|}~' # noqa: W605 + self.punction_without_space_before_pattern = re.compile( + r"[^\s][{}]".format(self.punctuation_symbols) + ) + self.punctuation_with_space_around_pattern = ( + self._compile_space_around_punctuation_pattern() + ) + + try: + if pretrained_vocab_file is not None: + # Hack because, honestly this tokenizer was not made to be used + # in a library like ours, at all. 
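The pretrained vocabulary file handled in the next few lines is essentially the tokenizer's attribute dictionary serialized with torch.save, and loading it just copies those attributes back onto the instance. A toy round trip of that idea; the file name and fields below are illustrative:

import torch

# Illustrative only: save a plain attribute dict and restore it onto an object,
# mirroring how the pretrained vocab file is loaded just below.
state = {"idx2sym": ["<eos>", "the", "dog"], "sym2idx": {"<eos>": 0, "the": 1, "dog": 2}}
torch.save(state, "toy_vocab.bin")

class Holder:
    pass

holder = Holder()
for key, value in torch.load("toy_vocab.bin").items():
    if key not in holder.__dict__:
        holder.__dict__[key] = value

print(holder.sym2idx["dog"])  # 2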
+ vocab_dict = torch.load(pretrained_vocab_file) + for key, value in vocab_dict.items(): + if key not in self.__dict__: + self.__dict__[key] = value + + if vocab_file is not None: + self.build_vocab() + except Exception: + raise ValueError( + "Unable to parse file {}. Unknown format. " + "If you tried to load a model saved through TransfoXLTokenizerFast," + "please note they are not compatible.".format(pretrained_vocab_file) + ) + + if vocab_file is not None: + self.build_vocab() + + def _compile_space_around_punctuation_pattern(self): + look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols) + look_ahead_to_match_all_except_space = "(?=[^\s])" # noqa: W605 + return re.compile( + r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space + ) + + def count_file(self, path, verbose=False, add_eos=False): + if verbose: + logger.info("counting file {} ...".format(path)) + assert os.path.exists(path) + + sents = [] + with open(path, "r", encoding="utf-8") as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(" line {}".format(idx)) + symbols = self.tokenize(line, add_eos=add_eos) + self.counter.update(symbols) + sents.append(symbols) + + return sents + + def count_sents(self, sents, verbose=False): + """ + sents : a list of sentences, each a list of tokenized symbols + """ + if verbose: + logger.info("counting {} sents ...".format(len(sents))) + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(" line {}".format(idx)) + self.counter.update(symbols) + + def _build_from_file(self, vocab_file): + self.idx2sym = [] + self.sym2idx = OrderedDict() + + with open(vocab_file, "r", encoding="utf-8") as f: + for line in f: + symb = line.strip().split()[0] + self.add_symbol(symb) + if "" in self.sym2idx: + self.unk_idx = self.sym2idx[""] + elif "" in self.sym2idx: + self.unk_idx = self.sym2idx[""] + else: + raise ValueError("No token in vocabulary") + + def save_vocabulary(self, vocab_path): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + vocab_path (:obj:`str`): + The directory in which to save the vocabulary. + + Returns: + :obj:`Tuple(str)`: Paths to the files saved. + """ + + logger.warning( + "Please note you will not be able to load the save vocabulary in" + " Rust-based TransfoXLTokenizerFast as they don't share the same structure." 
+ ) + + if os.path.isdir(vocab_path): + vocab_file = os.path.join( + vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"] + ) + else: + vocab_file = vocab_path + torch.save(self.__dict__, vocab_file) + return (vocab_file,) + + def build_vocab(self): + if self.vocab_file: + logger.info("building vocab from {}".format(self.vocab_file)) + self._build_from_file(self.vocab_file) + logger.info("final vocab size {}".format(len(self))) + else: + logger.info( + "building vocab with min_freq={}, max_size={}".format( + self.min_freq, self.max_size + ) + ) + self.idx2sym = [] + self.sym2idx = OrderedDict() + + for sym in self.special: + self.add_special(sym) + + for sym, cnt in self.counter.most_common(self.max_size): + if cnt < self.min_freq: + break + self.add_symbol(sym) + + logger.info( + "final vocab size {} from {} unique tokens".format( + len(self), len(self.counter) + ) + ) + + def encode_file( + self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False + ): + if verbose: + logger.info("encoding file {} ...".format(path)) + assert os.path.exists(path) + encoded = [] + with open(path, "r", encoding="utf-8") as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(" line {}".format(idx)) + symbols = self.tokenize( + line, add_eos=add_eos, add_double_eos=add_double_eos + ) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def encode_sents(self, sents, ordered=False, verbose=False): + if verbose: + logger.info("encoding {} sents ...".format(len(sents))) + encoded = [] + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(" line {}".format(idx)) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def add_special(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym]) + + def add_symbol(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + + def _convert_id_to_token(self, idx): + """Converts an id in a token (BPE) using the vocab.""" + assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx) + return self.idx2sym[idx] + + def _convert_token_to_id(self, sym): + """Converts a token (str) in an id using the vocab.""" + if sym in self.sym2idx: + return self.sym2idx[sym] + else: + # logger.info('encounter unk {}'.format(sym)) + # assert '' not in sym + if hasattr(self, "unk_idx"): + return self.sym2idx.get(sym, self.unk_idx) + # Backward compatibility with pre-trained models + elif "" in self.sym2idx: + return self.sym2idx[""] + elif "" in self.sym2idx: + return self.sym2idx[""] + else: + raise ValueError( + "Token not in vocabulary and no token in vocabulary for replacement" + ) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).strip() + return out_string + + def convert_to_tensor(self, symbols): + return torch.LongTensor(self.convert_tokens_to_ids(symbols)) + + @property + def vocab_size(self): + return len(self.idx2sym) + + def get_vocab(self): + return dict(self.sym2idx, **self.added_tokens_encoder) + + def _tokenize(self, line, add_eos=False, add_double_eos=False): + line = line.strip() + # convert to lower case + if self.lower_case: + line = line.lower() 
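add_symbol and _convert_token_to_id above maintain the two-way mapping between symbols and ids, with out-of-vocabulary symbols falling back to the unknown token. A minimal standalone version of that bookkeeping; the token strings are illustrative:

from collections import OrderedDict

idx2sym, sym2idx = [], OrderedDict()

def add_symbol(sym):
    # Mirrors TransfoXLTokenizer.add_symbol: append once, remember its index.
    if sym not in sym2idx:
        idx2sym.append(sym)
        sym2idx[sym] = len(idx2sym) - 1

for sym in ["<unk>", "<eos>", "the", "dog"]:
    add_symbol(sym)

def token_to_id(sym):
    # Unknown symbols map to the <unk> index, as in _convert_token_to_id.
    return sym2idx.get(sym, sym2idx["<unk>"])

print(token_to_id("dog"), token_to_id("aardvark"))  # 3 0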
+ + # empty delimiter '' will evaluate False + if self.delimiter == "": + symbols = line + else: + symbols = line.split(self.delimiter) + + if add_double_eos: # lm1b + return [""] + symbols + [""] + elif add_eos: + return symbols + [""] + else: + return symbols + + def prepare_for_tokenization(self, text, **kwargs): + # add spaces before punctuation symbols as should be done in transfo-xl + text = self.punctuation_with_space_around_pattern.sub(r" ", text) + + # if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]: + # text = self.punctuation_with_space_around_pattern.sub(r" ", text) + # elif self.punction_without_space_before_pattern.search(text): + # # searches until the first occurence of a punctuation symbol without surrounding spaces + # logger.warning( + # "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `` token" + # ) + + return text + + +class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer): + def __init__( + self, + vocab_file, + delimiter, + lowercase, + unk_token, + eos_token, + add_eos=False, + add_double_eos=False, + normalization: Optional[str] = None, + ): + + try: + tokenizer = WordLevel.from_files(vocab_file, unk_token=unk_token) + tokenizer = Tokenizer(tokenizer) + except Exception: + raise ValueError( + "Unable to parse file {}. Unknown format. " + "If you tried to load a model saved through TransfoXLTokenizer," + "please note they are not compatible.".format(vocab_file) + ) + + # Create the correct normalization path + normalizer = [] + + # Include unicode normalization + if normalization: + normalizer += [unicode_normalizer_from_str(normalization)] + + # Include case normalization + if lowercase: + normalizer += [Lowercase()] + + if len(normalizer) > 0: + tokenizer.normalizer = ( + Sequence(normalizer) if len(normalizer) > 1 else normalizer[0] + ) + + # Setup the splitter + tokenizer.pre_tokenizer = ( + CharDelimiterSplit(delimiter) if delimiter else WhitespaceSplit() + ) + + if add_double_eos: + tokenizer.post_processor = BertProcessing( + (eos_token, tokenizer.token_to_id(eos_token)), + (eos_token, tokenizer.token_to_id(eos_token)), + ) + + parameters = { + "model": "TransfoXLModel", + "add_eos": add_eos, + "add_double_eos": add_double_eos, + "unk_token": unk_token, + "eos_token": eos_token, + "delimiter": delimiter, + "lowercase": lowercase, + } + + super().__init__(tokenizer, parameters) + + def encode_batch( + self, sequences: List[Union[str, Tuple[str, str]]] + ) -> List[Encoding]: + return super().encode_batch( + [ + seq.strip() + if isinstance(seq, str) + else (seq[0].strip(), seq[1].strip()) + for seq in sequences + ] + ) + + def encode(self, sequence: str, pair: Optional[str] = None) -> Encoding: + return super().encode(sequence.strip(), pair.strip() if pair else pair) + + +class TransfoXLTokenizerFast(PreTrainedTokenizerFast): + + vocab_files_names = VOCAB_FILES_NAMES_FAST + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + special=None, + min_freq=0, + max_size=None, + lower_case=False, + delimiter=None, + vocab_file=None, + pretrained_vocab_file=None, + never_split=None, + unk_token="", + eos_token="", + additional_special_tokens=[""], + add_eos=False, + add_double_eos=False, + normalization=None, + **kwargs + ): + + super().__init__( + _TransfoXLDelimiterLookupTokenizer( + 
vocab_file=vocab_file or pretrained_vocab_file, + delimiter=delimiter, + lowercase=lower_case, + unk_token=unk_token, + eos_token=eos_token, + add_eos=add_eos, + add_double_eos=add_double_eos, + normalization=normalization, + ), + unk_token=unk_token, + eos_token=eos_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + def save_pretrained(self, save_directory): + logger.warning( + "Please note you will not be able to load the vocabulary in" + " Python-based TransfoXLTokenizer as they don't share the same structure." + ) + + return super().save_pretrained(save_directory) + + +class LMOrderedIterator(object): + def __init__(self, data, bsz, bptt, device="cpu", ext_len=None): + """ + data -- LongTensor -- the LongTensor is strictly ordered + """ + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + + # Work out how cleanly we can divide the dataset into bsz parts. + self.n_step = data.size(0) // bsz + + # Trim off any extra elements that wouldn't cleanly fit (remainders). + data = data.narrow(0, 0, self.n_step * bsz) + + # Evenly divide the data across the bsz batches. + self.data = data.view(bsz, -1).t().contiguous().to(device) + + # Number of mini-batches + self.n_batch = (self.n_step + self.bptt - 1) // self.bptt + + def get_batch(self, i, bptt=None): + if bptt is None: + bptt = self.bptt + seq_len = min(bptt, self.data.size(0) - 1 - i) + + end_idx = i + seq_len + beg_idx = max(0, i - self.ext_len) + + data = self.data[beg_idx:end_idx] + target = self.data[i + 1 : i + 1 + seq_len] + + data_out = data.transpose(0, 1).contiguous().to(self.device) + target_out = target.transpose(0, 1).contiguous().to(self.device) + + return data_out, target_out, seq_len + + def get_fixlen_iter(self, start=0): + for i in range(start, self.data.size(0) - 1, self.bptt): + yield self.get_batch(i) + + def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): + max_len = self.bptt + max_deviation * std + i = start + while True: + bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0 + bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) + data, target, seq_len = self.get_batch(i, bptt) + i += seq_len + yield data, target, seq_len + if i >= self.data.size(0) - 2: + break + + def __iter__(self): + return self.get_fixlen_iter() + + +class LMShuffledIterator(object): + def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False): + """ + data -- list[LongTensor] -- there is no order among the LongTensors + """ + self.data = data + + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self): + # index iterator + epoch_indices = ( + np.random.permutation(len(self.data)) + if self.shuffle + else np.array(range(len(self.data))) + ) + + # sentence iterator + for idx in epoch_indices: + yield self.data[idx] + + def stream_iterator(self, sent_stream): + # streams for each data in the batch + streams = [None] * self.bsz + + data = torch.LongTensor(self.bptt, self.bsz) + target = torch.LongTensor(self.bptt, self.bsz) + + n_retain = 0 + + while True: + # data : [n_retain+bptt x bsz] + # target : [bptt x bsz] + data[n_retain:].fill_(-1) + target.fill_(-1) + + valid_batch = True + + for i in range(self.bsz): + n_filled = 0 + try: + while n_filled < self.bptt: + if streams[i] is None or len(streams[i]) <= 1: + streams[i] = next(sent_stream) + # number of new tokens to fill 
in + n_new = min(len(streams[i]) - 1, self.bptt - n_filled) + # first n_retain tokens are retained from last batch + data[ + n_retain + n_filled : n_retain + n_filled + n_new, i + ] = streams[i][:n_new] + target[n_filled : n_filled + n_new, i] = streams[i][ + 1 : n_new + 1 + ] + streams[i] = streams[i][n_new:] + n_filled += n_new + except StopIteration: + valid_batch = False + break + + if not valid_batch: + return + + data_out = data.transpose(0, 1 ).contiguous().to(self.device) + target_out = target.transpose(0, 1).contiguous().to(self.device) + + yield data_out, target_out, self.bptt + + n_retain = min(data.size(0), self.ext_len) + if n_retain > 0: + data[:n_retain] = data[-n_retain:] + data.resize_(n_retain + self.bptt, data.size(1)) + + def __iter__(self): + # sent_stream is an iterator + sent_stream = self.get_sent_stream() + + for batch in self.stream_iterator(sent_stream): + yield batch + + +class LMMultiFileIterator(LMShuffledIterator): + def __init__( + self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False + ): + + self.paths = paths + self.vocab = vocab + + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self, path): + sents = self.vocab.encode_file(path, add_double_eos=True) + if self.shuffle: + np.random.shuffle(sents) + sent_stream = iter(sents) + + return sent_stream + + def __iter__(self): + if self.shuffle: + np.random.shuffle(self.paths) + + for path in self.paths: + # sent_stream is an iterator + sent_stream = self.get_sent_stream(path) + for batch in self.stream_iterator(sent_stream): + yield batch + + +class TransfoXLCorpus(object): + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs + ): + """ + Instantiate a pre-processed corpus. + """ + vocab = TransfoXLTokenizer.from_pretrained( + pretrained_model_name_or_path, *inputs, **kwargs + ) + if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP: + corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME) + # redirect to the cache, if necessary + try: + resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Corpus '{}' was not found in corpus list ({}). " + "We assumed '{}' was a path or url but couldn't find files {} " + "at this path or url.".format( + pretrained_model_name_or_path, + ", ".join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()), + pretrained_model_name_or_path, + corpus_file, + ) + ) + return None + if resolved_corpus_file == corpus_file: + logger.info("loading corpus file {}".format(corpus_file)) + else: + logger.info( + "loading corpus file {} from cache at {}".format( + corpus_file, resolved_corpus_file + ) + ) + + # Instantiate tokenizer. 
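LMOrderedIterator above lays a single ordered token stream out as bsz parallel columns and then slices out (data, target) mini-batches whose targets are offset by one position. The toy sketch below reproduces that reshaping with made-up sizes:

import torch

# Toy ordered corpus of 12 token ids, batch size 3, bptt (sequence length) 2.
stream = torch.arange(12)
bsz, bptt = 3, 2

n_step = stream.size(0) // bsz
data = stream.narrow(0, 0, n_step * bsz).view(bsz, -1).t().contiguous()  # (n_step, bsz)

# One mini-batch starting at position i: inputs and next-token targets.
i = 0
seq_len = min(bptt, data.size(0) - 1 - i)
batch_data = data[i : i + seq_len]            # (seq_len, bsz)
batch_target = data[i + 1 : i + 1 + seq_len]  # same positions shifted by one
print(batch_data.t().tolist())
print(batch_target.t().tolist())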
+ corpus = cls(*inputs, **kwargs) + corpus_dict = torch.load(resolved_corpus_file) + for key, value in corpus_dict.items(): + corpus.__dict__[key] = value + corpus.vocab = vocab + if corpus.train is not None: + corpus.train = torch.tensor(corpus.train, dtype=torch.long) + if corpus.valid is not None: + corpus.valid = torch.tensor(corpus.valid, dtype=torch.long) + if corpus.test is not None: + corpus.test = torch.tensor(corpus.test, dtype=torch.long) + return corpus + + def __init__(self, *args, **kwargs): + self.vocab = TransfoXLTokenizer(*args, **kwargs) + self.dataset = None + self.train = None + self.valid = None + self.test = None + + def build_corpus(self, path, dataset): + self.dataset = dataset + + if self.dataset in ["ptb", "wt2", "enwik8", "text8"]: + sel f.vocab.count_file(os.path.join(path, "train.txt")) + self.vocab.count_file(os.path.join(path, "valid.txt")) + self.vocab.count_file(os.path.join(path, "test.txt")) + elif self.dataset == "wt103": + self.vocab.count_file(os.path.join(path, "train.txt")) + elif self.dataset == "lm1b": + train_path_pattern = os.path.join( + path, + "1-billion-word-language-modeling-benchmark-r13output", + "training-monolingual.tokenized.shuffled", + "news.en-*", + ) + train_paths = glob.glob(train_path_pattern) + # the vocab will load from file when build_vocab() is called + + self.vocab.build_vocab() + + if self.dataset in ["ptb", "wt2", "wt103"]: + self.train = self.vocab.encode_file( + os.path.join(path, "train.txt"), ordered=True + ) + self.valid = self.vocab.encode_file( + os.path.join(path, "valid.txt"), ordered=True + ) + self.test = self.vocab.encode_file( + os.path.join(path, "test.txt"), ordered=True + ) + elif self.dataset in ["enwik8", "text8"]: + self.train = self.vocab.encode_file( + os.path.join(path, "train.txt"), ordered=True, add_eos=False + ) + self.valid = self.vocab.encode_file( + os.path.join(path, "valid.txt"), ordered=True, add_eos=False + ) + self.test = self.vocab.encode_file( + os.path.join(path, "test.txt"), ordered=True, add_eos=False + ) + elif self.dataset == "lm1b": + self.train = train_paths + self.valid = self.vocab.encode_file( + os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True + ) + self.test = self.vocab.encode_file( + os.path.join(path, "test.txt"), ordered=False, add_double_eos=True + ) + + def get_iterator(self, split, *args, **kwargs): + if split == "train": + if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: + data_iter = LMOrderedIterator(self.train, *args, **kwargs) + elif self.dataset == "lm1b": + kwargs["shuffle"] = True + data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) + elif split in ["valid", "test"]: + data = self.valid if split == "valid" else self.test + if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: + data_iter = LMOrderedIterator(data, *args, **kwargs) + elif self.dataset == "lm1b": + data_iter = LMShuffledIterator(data, *args, **kwargs) + + return data_iter + + +def get_lm_corpus(datadir, dataset): + fn = os.path.join(datadir, "cache.pt") + fn_pickle = os.path.join(datadir, "cache.pkl") + if os.path.exists(fn): + logger.info("Loading cached dataset...") + corpus = torch.load(fn_pickle) + elif os.path.exists(fn): + logger.info("Loading cached dataset from pickle...") + with open(fn, "rb") as fp: + corpus = pickle.load(fp) + else: + logger.info("Producing dataset {}...".format(dataset)) + kwargs = {} + if dataset in ["wt103", "wt2"]: + kwargs["special"] = [""] + kwargs["lower_case"] = False + elif dataset == "ptb": + 
kwargs["special"] = [""] + kwargs["lower_case"] = True + elif dataset == "lm1b": + kwargs["special"] = [] + kwargs["lower_case"] = False + kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt") + elif dataset in ["enwik8", "text8"]: + pass + + corpus = TransfoXLCorpus(datadir, dataset, **kwargs) + torch.save(corpus, fn) + + return corpus diff --git a/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/tokenization_utils.py b/model/th ird_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/tokenization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..150d879c5cac5f762f11781294100a71811cb323 --- /dev/null +++ b/model/third_party/HMNet/ThirdParty/Huggingface/Transformers/src/transformers/tokenization_utils.py @@ -0,0 +1,2166 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" + +import copy +import itertools +import json +import logging +import os +import re +from collections import defaultdict +from contextlib import contextmanager +from typing import List, Optional, Tuple, Union + +from tokenizers.implementations import BaseTokenizer + +from .file_utils import ( + cached_path, + hf_bucket_url, + is_remote_url, + is_tf_available, + is_torch_available, +) + + +if is_tf_available(): + import tensorflow as tf +if is_torch_available(): + import torch + +logger = logging.getLogger(__name__) + +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +ADDED_TOKENS_FILE = "added_tokens.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" + + +@contextmanager +def truncate_and_pad( + tokenizer: BaseTokenizer, + max_length: int, + stride: int, + strategy: str, + pad_to_max_length: bool, + padding_side: str, + pad_token_id: int, + pad_token_type_id: int, + pad_token: str, +): + """ + This contextmanager is in charge of defining the truncation and the padding strategies and then + restore the tokenizer settings afterwards. + + This contextmanager assumes the provider tokenizer has no padding / truncation strategy + before the managed section. If your tokenizer set a padding / truncation strategy before, + then it will be reset to no padding/truncation when exiting the managed section. 
+ + :param tokenizer: + :param max_length: + :param stride: + :param strategy: + :param pad_to_max_length: + :param padding_side: + :param pad_token_id: + :param pad_token_type_id: + :param pad_token: + :return: + """ + + # Handle all the truncation and padding stuff + if max_length is not None: + tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy) + + if pad_to_max_length and (pad_token and pad_token_id >= 0): + tokenizer.enable_padding( + max_length=max_length, + direction=padding_side, + pad_id=pad_token_id, + pad_type_id=pad_token_type_id, + pad_token=pad_token, + ) + elif pad_to_max_length: + logger.warning( + "Disabled padding because no padding token set (pad_token: {}, pad_token_id: {}).\n" + "To remove this error, you can add a new pad token and then resize model embedding:\n" + "\ttokenizer.pad_token = ''\n\tmodel.resize_token_embeddings(len(tokenizer))".format( + pad_token, pad_token_id + ) + ) + + yield + + if max_length is not None: + tokenizer.no_truncation() + + if pad_to_max_length and (pad_token and pad_token_id >= 0): + tokenizer.no_padding() + + +class PreTrainedTokenizer(object): + """Base class for all tokenizers. + Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. + + This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). + + Class attributes (overridden by derived classes): + + - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). + - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file. + - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size. + - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method. + + Parameters: + + - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id`` + + - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id`` + + - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id`` + + - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id`` + + - ``pad_token``: (`Optional`) string: a padding token. 
Will be associated to ``self.pad_token`` and ``self.pad_token_id`` + + - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id`` + + - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id`` + + - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` + """ + + vocab_files_names = {} + pretrained_vocab_files_map = {} + pretrained_init_configuration = {} + max_model_input_sizes = {} + model_input_names = ["token_type_ids", "attention_mask"] + + SPECIAL_TOKENS_ATTRIBUTES = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + "additional_special_tokens", + ] + + padding_side = "right" + + NO_PAD_TOKEN_FOR_BATCH_MSG = ( + "No padding token is set for this model, therefore no batch can be made with uneven " + "sequences. Set a padding token or adjust the lengths of the sequences building the " + "batch so that every sequence is of the same length." + ) + + UNEVEN_SEQUENCES_FOR_BATCH_MSG = ( + "The sequences building the batch are not of the same size, no tensor " + "can be built. Set `pad_to_max_length=True` to pad the smaller sequences" + "up to the larger sequence's length." + ) + + @property + def bos_token(self): + """Beginning of sentence token (string). Log an error if used while not having been set.""" + if self._bos_token is None: + logger.error("Using bos_token, but it is not set yet.") + return self._bos_token + + @property + def eos_token(self): + """End of sentence token (string). Log an error if used while not having been set.""" + if self._eos_token is None: + logger.error("Using eos_token, but it is not set yet.") + return self._eos_token + + @property + def unk_token(self): + """Unknown token (string). Log an error if used while not having been set.""" + if self._unk_token is None: + logger.error("Using unk_token, but it is not set yet.") + return self._unk_token + + @property + def sep_token(self): + """Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set.""" + if self._sep_token is None: + logger.error("Using sep_token, but it is not set yet.") + return self._sep_token + + @property + def pad_token(self): + """Padding token (string). Log an error if used while not having been set.""" + if self._pad_token is None: + logger.error("Using pad_token, but it is not set yet.") + return self._pad_token + + @property + def cls_token(self): + """Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set.""" + if self._cls_token is None: + logger.error("Using cls_token, but it is not set yet.") + return self._cls_token + + @property + def mask_token(self): + """Mask token (string). E.g. when training a model with masked-language modeling. 
Log an error if used while not having been set.""" + if self._mask_token is None: + logger.error("Using mask_token, but it is not set yet.") + return self._mask_token + + @property + def additional_special_tokens(self): + """All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set.""" + if self._additional_special_tokens is None: + logger.error("Using additional_special_tokens, but it is not set yet.") + return self._additional_special_tokens + + @bos_token.setter + def bos_token(self, value): + self._bos_token = value + + @eos_token.setter + def eos_token(self, value): + self._eos_token = value + + @unk_token.setter + def unk_token(self, value): + self._unk_token = value + + @sep_token.setter + def sep_token(self, value): + self._sep_token = value + + @pad_token.setter + def pad_token(self, value): + self._pad_token = value + + @cls_token.setter + def cls_token(self, value): + self._cls_token = value + + @mask_token.setter + def mask_token(self, value): + self._mask_token = value + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + + @property + def bos_token_id(self): + """Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.bos_token) + + @property + def eos_token_id(self): + """Id of the end of sentence token in the vocabulary. Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.eos_token) + + @property + def unk_token_id(self): + """Id of the unknown token in the vocabulary. Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.unk_token) + + @property + def sep_token_id(self): + """Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.sep_token) + + @property + def pad_token_id(self): + """Id of the padding token in the vocabulary. Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.pad_token) + + @property + def pad_token_type_id(self): + """Id of the padding token type in the vocabulary.""" + return self._pad_token_type_id + + @property + def cls_token_id(self): + """Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.cls_token) + + @property + def mask_token_id(self): + """Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.mask_token) + + @property + def additional_special_tokens_ids(self): + """Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set.""" + return self.convert_tokens_to_ids(self.additional_special_tokens) + + def get_vocab(self): + """Returns the vocabulary as a dict of {token: index} pairs. 
`tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab.""" + raise NotImplementedError() + + def __init__(self, max_len=None, **kwargs): + self._bos_token = None + self._eos_token = None + self._unk_token = None + self._sep_token = None + self._pad_token = None + self._cls_token = None + self._mask_token = None + self._pad_token_type_id = 0 + self._additional_special_tokens = [] + + self.max_len = max_len if max_len is not None else int(1e12) + + # Padding side is right by default and over-riden in subclasses. If specified in the kwargs, it is changed. + self.padding_side = kwargs.pop("padding_side", self.padding_side) + self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) + + # Added tokens + self.added_tokens_encoder = {} + self.unique_added_tokens_encoder = set() + self.added_tokens_decoder = {} + + # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) + self.init_inputs = () + self.init_kwargs = {} + + for key, value in kwargs.items(): + if key in self.SPECIAL_TOKENS_ATTRIBUTES: + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)) and all( + isinstance(t, str) for t in value + ) + else: + assert isinstance(value, str) + setattr(self, key, value) + + @classmethod + def from_pretrained(cls, *inputs, **kwargs): + r""" + Instantiate a :class:`~transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer. + + Args: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``. + - (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the vocabulary files and override the cached versions if they exists. + + resume_download: (`optional`) boolean, default False: + Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. + + kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details. 
+
+ Examples::
+
+ # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer
+
+ # Download vocabulary from S3 and cache.
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+ # Download vocabulary from S3 (user-uploaded) and cache.
+ tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
+
+ # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
+ tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
+
+ # If the tokenizer uses a single vocabulary file, you can point directly to this file
+ tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
+
+ # You can link tokens to special vocabulary when instantiating
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
+ # You should be sure '<unk>' is in the vocabulary when doing that.
+ # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
+ assert tokenizer.unk_token == '<unk>'
+
+ """
+ return cls._from_pretrained(*inputs, **kwargs)
+
+ @classmethod
+ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+
+ s3_models = list(cls.max_model_input_sizes.keys())
+ vocab_files = {}
+ init_configuration = {}
+ if pretrained_model_name_or_path in s3_models:
+ # Get the vocabulary from AWS S3 bucket
+ for file_id, map_list in cls.pretrained_vocab_files_map.items():
+ vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+ if (
+ cls.pretrained_init_configuration
+ and pretrained_model_name_or_path in cls.pretrained_init_configuration
+ ):
+ init_configuration = cls.pretrained_init_configuration[
+ pretrained_model_name_or_path
+ ].copy()
+ else:
+ # Get the vocabulary from local files
+ logger.info(
+ "Model name '{}' not found in model shortcut name list ({}). "
+ "Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
+ pretrained_model_name_or_path,
+ ", ".join(s3_models),
+ pretrained_model_name_or_path,
+ )
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
+ pretrained_model_name_or_path
+ ):
+ if len(cls.vocab_files_names) > 1:
+ raise ValueError(
+ "Calling {}.from_pretrained() with the path to a single file or url is not supported."
+ "Use a model identifier or the path to a directory instead.".format( + cls.__name__ + ) + ) + logger.warning( + "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format( + cls.__name__ + ) + ) + file_id = list(cls.vocab_files_names.keys())[0] + vocab_files[file_id] = pretrained_model_name_or_path + else: + # At this point pretrained_model_name_or_path is either a directory or a model identifier name + additional_files_names = { + "added_tokens_file": ADDED_TOKENS_FILE, + "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, + "tokenizer_config_file": TOKENIZER_CONFIG_FILE, + } + # Look for the tokenizer main vocabulary files + the additional tokens files + for file_id, file_name in { + **cls.vocab_files_names, + **additional_files_names, + }.items(): + if os.path.isdir(pretrained_model_name_or_path): + full_file_name = os.path.join( + pretrained_model_name_or_path, file_name + ) + if not os.path.exists(full_file_name): + logger.info( + "Didn't find file {}. We won't load it.".format( + full_file_name + ) + ) + full_file_name = None + else: + full_file_name = hf_bucket_url( + pretrained_model_name_or_path, postfix=file_name + ) + + vocab_files[file_id] = full_file_name + + # Get files from url, cache, or disk depending on the case + try: + resolved_vocab_files = {} + for file_id, file_path in vocab_files.items(): + if file_path is None: + resolved_vocab_files[file_id] = None + else: + resolved_vocab_files[file_id] = cached_path( + file_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + ) + except EnvironmentError: + if pretrained_model_name_or_path in s3_models: + msg = "Couldn't reach server at '{}' to download vocabulary files." + else: + msg = ( + "Model name '{}' was not found in tokenizers model name list ({}). " + "We assumed '{}' was a path or url to a directory containing vocabulary files " + "named {}, but couldn't find such vocabulary files at this path or url.".format( + pretrained_model_name_or_path, + ", ".join(s3_models), + pretrained_model_name_or_path, + list(cls.vocab_files_names.values()), + ) + ) + + raise EnvironmentError(msg) + + if all( + full_file_name is None for full_file_name in resolved_vocab_files.values() + ): + raise EnvironmentError( + "Model name '{}' was not found in tokenizers model name list ({}). " + "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files " + "named {} but couldn't find such vocabulary files at this path or url.".format( + pretrained_model_name_or_path, + ", ".join(s3_models), + pretrained_model_name_or_path, + list(cls.vocab_files_names.values()), + ) + ) + + for file_id, file_path in vocab_files.items(): + if file_path == resolved_vocab_files[file_id]: + logger.info("loading file {}".format(file_path)) + else: + logger.info( + "loading file {} from cache at {}".format( + file_path, resolved_vocab_files[file_id] + ) + ) + + # Prepare tokenizer initialization kwargs + # Did we saved some inputs and kwargs to reload ? 
+ tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None) + if tokenizer_config_file is not None: + with open( + tokenizer_config_file, encoding="utf-8" + ) as tokenizer_config_handle: + init_kwargs = json.load(tokenizer_config_handle) + saved_init_inputs = init_kwargs.pop("init_inputs", ()) + if not init_inputs: + init_inputs = saved_init_inputs + else: + init_kwargs = init_configuration + + # Update with newly provided kwargs + init_kwargs.update(kwargs) + + # Set max length if needed + if pretrained_model_name_or_path in cls.max_model_input_sizes: + # if we're using a pretrained model, ensure the tokenizer + # wont index sequences longer than the number of positional embeddings + max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] + if max_len is not None and isinstance(max_len, (int, float)): + init_kwargs["max_len"] = min( + init_kwargs.get("max_len", int(1e12)), max_len + ) + + # Merge resolved_vocab_files arguments in init_kwargs. + added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None) + special_tokens_map_file = resolved_vocab_files.pop( + "special_tokens_map_file", None + ) + for args_name, file_path in resolved_vocab_files.items(): + if args_name not in init_kwargs: + init_kwargs[args_name] = file_path + if special_tokens_map_file is not None: + with open( + special_tokens_map_file, encoding="utf-8" + ) as special_tokens_map_handle: + special_tokens_map = json.load(special_tokens_map_handle) + for key, value in special_tokens_map.items(): + if key not in init_kwargs: + init_kwargs[key] = value + + # Instantiate tokenizer. + try: + tokenizer = cls(*init_inputs, **init_kwargs) + except OSError: + raise OSError( + "Unable to load vocabulary from file. " + "Please check that the provided vocabulary is accessible and not corrupted." + ) + + # Save inputs and kwargs for saving and re-loading with ``save_pretrained`` + tokenizer.init_inputs = init_inputs + tokenizer.init_kwargs = init_kwargs + + # update unique_added_tokens_encoder with special tokens for correct tokenization + tokenizer.unique_added_tokens_encoder.update(set(tokenizer.all_special_tokens)) + + # Add supplementary tokens. + if added_tokens_file is not None: + with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: + added_tok_encoder = json.load(added_tokens_handle) + added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} + tokenizer.added_tokens_encoder.update(added_tok_encoder) + tokenizer.added_tokens_decoder.update(added_tok_decoder) + tokenizer.unique_added_tokens_encoder.update( + set(tokenizer.added_tokens_encoder.keys()) + ) + + return tokenizer + + def save_pretrained(self, save_directory): + """Save the tokenizer vocabulary files together with: + - added tokens, + - special-tokens-to-class-attributes-mapping, + - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert). + + This won't save modifications other than (added tokens and special token mapping) you may have + applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation). + + This method make sure the full tokenizer can then be re-loaded using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. 
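# [Editor's note] Illustrative sketch only, not part of the patch. A minimal
# save/reload round trip under the contract documented above, assuming a
# concrete subclass such as BertTokenizer and a writable directory:
import os

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")
tok.add_tokens(["new_tok1"])
os.makedirs("./my_tok", exist_ok=True)  # save_pretrained expects an existing directory
tok.save_pretrained("./my_tok")  # writes the vocab file plus the three JSON files above
reloaded = BertTokenizer.from_pretrained("./my_tok")
assert "new_tok1" in reloaded.added_tokens_encoder
# [End editor's note]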
+ """ + if not os.path.isdir(save_directory): + logger.error( + "Saving directory ({}) should be a directory".format(save_directory) + ) + return + + special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) + added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) + tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE) + + tokenizer_config = copy.deepcopy(self.init_kwargs) + if len(self.init_inputs) > 0: + tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) + for file_id in self.vocab_files_names.keys(): + tokenizer_config.pop(file_id, None) + + with open(tokenizer_config_file, "w", encoding="utf-8") as f: + f.write(json.dumps(tokenizer_config, ensure_ascii=False)) + + with open(special_tokens_map_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) + + if len(self.added_tokens_encoder) > 0: + with open(added_tokens_file, "w", encoding="utf-8") as f: + out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False) + f.write(out_str) + + vocab_files = self.save_vocabulary(save_directory) + + return vocab_files + (special_tokens_map_file, added_tokens_file) + + def save_vocabulary(self, save_directory): + """Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens + and special token mappings. + + Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. + """ + raise NotImplementedError + + def vocab_size(self): + """Size of the base vocabulary (without the added tokens)""" + raise NotImplementedError + + def __len__(self): + """Size of the full vocabulary with the added tokens""" + return self.vocab_size + len(self.added_tokens_encoder) + + def add_tokens(self, new_tokens): + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the + vocabulary, they are added to it with indices starting from length of the current vocabulary. + + Args: + new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). + + Returns: + Number of tokens added to the vocabulary. + + Examples:: + + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertModel.from_pretrained('bert-base-uncased') + + num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) + print('We have added', num_added_toks, 'tokens') + model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 
+ """ + if not new_tokens: + return 0 + + if not isinstance(new_tokens, list): + new_tokens = [new_tokens] + + to_add_tokens = [] + for token in new_tokens: + assert isinstance(token, str) + if ( + self.init_kwargs.get("do_lower_case", False) + and token not in self.all_special_tokens + ): + token = token.lower() + if ( + token != self.unk_token + and self.convert_tokens_to_ids(token) + == self.convert_tokens_to_ids(self.unk_token) + and token not in to_add_tokens + ): + to_add_tokens.append(token) + logger.info("Adding %s to the vocabulary", token) + + added_tok_encoder = dict( + (tok, len(self) + i) for i, tok in enumerate(to_add_tokens) + ) + added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} + self.added_tokens_encoder.update(added_tok_encoder) + self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union( + set(self.all_special_tokens) + ) + self.added_tokens_decoder.update(added_tok_decoder) + + return len(to_add_tokens) + + def num_added_tokens(self, pair=False): + """ + Returns the number of added tokens when encoding a sequence with special tokens. + + Note: + This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this + inside your training loop. + + Args: + pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the + number of added tokens in the case of a single sequence if set to False. + + Returns: + Number of tokens added to sequences + """ + token_ids_0 = [] + token_ids_1 = [] + return len( + self.build_inputs_with_special_tokens( + token_ids_0, token_ids_1 if pair else None + ) + ) + + def add_special_tokens(self, special_tokens_dict): + """ + Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them + to class attributes. If special tokens are NOT in the vocabulary, they are added + to it (indexed starting from the last index of the current vocabulary). + + Using `add_special_tokens` will ensure your special tokens can be used in several ways: + + - special tokens are carefully handled by the tokenizer (they are never split) + - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts. + + When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '') + + Args: + special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: + [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, + ``additional_special_tokens``]. + + Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). + + Returns: + Number of tokens added to the vocabulary. + + Examples:: + + # Let's see how to add a new classification token to GPT-2 + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + model = GPT2Model.from_pretrained('gpt2') + + special_tokens_dict = {'cls_token': ''} + + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + print('We have added', num_added_toks, 'tokens') + model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 
+ + assert tokenizer.cls_token == '' + """ + if not special_tokens_dict: + return 0 + + added_tokens = 0 + for key, value in special_tokens_dict.items(): + assert key in self.SPECIAL_TOKENS_ATTRIBUTES + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)) and all( + isinstance(t, str) for t in value + ) + added_tokens += self.add_tokens(value) + else: + assert isinstance(value, str) + added_tokens += self.add_tokens([value]) + logger.info("Assigning %s to the %s key of the tokenizer", value, key) + setattr(self, key, value) + + return added_tokens + + def tokenize(self, text, **kwargs): + """Converts a string in a sequence of tokens (string), using the tokenizer. + Split in words for word-based vocabulary or sub-words for sub-word-based + vocabularies (BPE/SentencePieces/WordPieces). + + Take care of added tokens. + + text: The sequence to be encoded. + add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence + begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`. + **kwargs: passed to the `prepare_for_tokenization` preprocessing method. + """ + all_special_tokens = self.all_special_tokens + text = self.prepare_for_tokenization(text, **kwargs) + + def lowercase_text(t): + # convert non-special tokens to lowercase + escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens] + pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" + return re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), t) + + if self.init_kwargs.get("do_lower_case", False): + text = lowercase_text(text) + + def split_on_token(tok, text): + result = [] + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + sub_text = sub_text.rstrip() + if i == 0 and not sub_text: + result += [tok] + elif i == len(split_text) - 1: + if sub_text: + result += [sub_text] + else: + pass + else: + if sub_text: + result += [sub_text] + result += [tok] + return result + + def split_on_tokens(tok_list, text): + if not text.strip(): + return [] + if not tok_list: + return self._tokenize(text) + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self.unique_added_tokens_encoder: + tokenized_text += split_on_token(tok, sub_text) + else: + tokenized_text += [sub_text] + text_list = tokenized_text + + return list( + itertools.chain.from_iterable( + ( + self._tokenize(token) + if token not in self.unique_added_tokens_encoder + else [token] + for token in tokenized_text + ) + ) + ) + + added_tokens = self.unique_added_tokens_encoder + tokenized_text = split_on_tokens(added_tokens, text) + return tokenized_text + + def _tokenize(self, text, **kwargs): + """Converts a string in a sequence of tokens (string), using the tokenizer. + Split in words for word-based vocabulary or sub-words for sub-word-based + vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def convert_tokens_to_ids(self, tokens): + """Converts a single token, or a sequence of tokens, (str) in a single integer id + (resp. a sequence of ids), using the vocabulary. 
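# [Editor's note] Illustrative sketch only, not part of the patch. The lookup
# order below is: added_tokens_encoder first (tokens registered through
# add_tokens / add_special_tokens), then the subclass's _convert_token_to_id()
# for the base vocabulary. Assuming a concrete subclass such as BertTokenizer:
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")
tok.add_tokens(["my_new-tok2"])
ids = tok.convert_tokens_to_ids(["hello", "my_new-tok2"])
assert ids[1] >= tok.vocab_size  # added tokens are indexed after the base vocab
# [End editor's note]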
+ """ + if tokens is None: + return None + + if isinstance(tokens, str): + return self._convert_token_to_id_with_added_voc(tokens) + + ids = [] + for token in tokens: + ids.append(self._convert_token_to_id_with_added_voc(token)) + return ids + + def _convert_token_to_id_with_added_voc(self, token): + if token is None: + return None + + if token in self.added_tokens_encoder: + return self.added_tokens_encoder[token] + return self._convert_token_to_id(token) + + def _convert_token_to_id(self, token): + raise NotImplementedError + + def encode( + self, + text: str, + text_pair: Optional[str] = None, + add_special_tokens: bool = True, + max_length: Optional[int] = None, + stride: int = 0, + truncation_strategy: str = "longest_first", + pad_to_max_length: bool = False, + return_tensors: Optional[str] = None, + **kwargs + ): + """ + Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. + + Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``. + + Args: + text (:obj:`str` or :obj:`List[str]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method) + text_pair (:obj:`str` or :obj:`List[str]`, `optional`, defaults to :obj:`None`): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized + string using the `tokenize` method) or a list of integers (tokenized string ids using the + `convert_tokens_to_ids` method) + add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to ``True``, the sequences will be encoded with the special tokens relative + to their model. + max_length (:obj:`int`, `optional`, defaults to :obj:`None`): + If set to a number, will limit the total sequence returned so that it has a maximum length. + If there are overflowing tokens, those will be added to the returned dictionary + stride (:obj:`int`, `optional`, defaults to ``0``): + If set to a number along with max_length, the overflowing tokens returned will contain some tokens + from the main sequence returned. The value of this argument defines the number of additional tokens. + truncation_strategy (:obj:`str`, `optional`, defaults to `longest_first`): + String selected in the following options: + + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences) + - 'only_first': Only truncate the first sequence + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + pad_to_max_length (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set to True, the returned sequences will be padded according to the model's padding side and + padding index, up to their max length. If no max length is specifie d, the padding is done up to the + model's max length. The tokenizer padding sides are handled by the class attribute `padding_side` + which can be set to the following strings: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + Defaults to False: no padding. + return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`): + Can be set to 'tf' or 'pt' to return respectively TensorFlow :obj:`tf.constant` + or PyTorch :obj:`torch.Tensor` instead of a list of python integers. 
+ **kwargs: passed to the `self.tokenize()` method + """ + encoded_inputs = self.encode_plus( + text, + text_pair=text_pair, + max_length=max_length, + add_special_tokens=add_special_tokens, + stride=stride, + truncation_strategy=truncation_strategy, + pad_to_max_length=pad_to_max_length, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def encode_plus( + self, + text: str, + text_pair: Optional[str] = None, + add_special_tokens: bool = True, + max_length: Optional[int] = None, + stride: int = 0, + truncation_strategy: str = "longest_first", + pad_to_max_length: bool = False, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + **kwargs + ): + """ + Returns a dictionary containing the encoded sequence or sequence pair and additional information: + the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. + + Args: + text (:obj:`str` or :obj:`List[str]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method) + text_pair (:obj:`str` or :obj:`List[str]`, `optional`, defaults to :obj:`None`): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized + string using the `tokenize` method) or a list of integers (tokenized string ids using the + `convert_tokens_to_ids` method) + add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to ``True``, the sequences will be encoded with the special tokens relative + to their model. + max_length (:obj:`int`, `optional`, defaults to :obj:`None`): + If set to a number, will limit the total sequence returned so that it has a maximum length. + If there are overflowing tokens, those will be added to the returned dictionary + stride (:obj:`int`, `optional`, defaults to ``0``): + If set to a number along with max_length, the overflowing tokens returned will contain some tokens + from the main sequence returned. The value of this argument defines the number of additional tokens. + truncation_strategy (:obj:`str`, `optional`, defaults to `longest_first`): + String selected in the following options: + + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences) + - 'only_first': Only truncate the first sequence + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + pad_to_max_length (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set to True, the returned sequences will be padded according to the model's padding side and + padding index, up to their max length. If no max length is specified, the padding is done up to the + model's max length. The tokenizer padding sides are handled by the class attribute `padding_side` + which can be set to the following strings: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + Defaults to False: no padding. 
+ return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
+ Can be set to 'tf' or 'pt' to return respectively TensorFlow :obj:`tf.constant`
+ or PyTorch :obj:`torch.Tensor` instead of a list of python integers.
+ return_token_type_ids (:obj:`bool`, `optional`, defaults to :obj:`None`):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
+
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
+ return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`none`):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
+
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True to return overflowing token information (default False).
+ return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True to return special tokens mask information (default False).
+ return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True to return (char_start, char_end) for each token (default False).
+ If using Python's tokenizer, this method will raise NotImplementedError. This one is only available on
+ Rust-based tokenizers inheriting from PreTrainedTokenizerFast.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ A Dictionary of shape::
+
+ {
+ input_ids: list[int],
+ token_type_ids: list[int] if return_token_type_ids is True (default)
+ attention_mask: list[int] if return_attention_mask is True (default)
+ overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
+ num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
+ special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
+ }
+
+ With the fields:
+
+ - ``input_ids``: list of token ids to be fed to a model
+ - ``token_type_ids``: list of token type ids to be fed to a model
+ - ``attention_mask``: list of indices specifying which tokens should be attended to by the model
+ - ``overflowing_tokens``: list of overflowing tokens if a max length is specified.
+ - ``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
+ - ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
+ tokens and 1 specifying sequence tokens.
+ """
+
+ def get_input_ids(text):
+ if isinstance(text, str):
+ tokens = self.tokenize(
+ text, add_special_tokens=add_special_tokens, **kwargs
+ )
+ return self.convert_tokens_to_ids(tokens)
+ elif (
+ isinstance(text, (list, tuple))
+ and len(text) > 0
+ and isinstance(text[0], str)
+ ):
+ return self.convert_tokens_to_ids(text)
+ elif (
+ isinstance(text, (list, tuple))
+ and len(text) > 0
+ and isinstance(text[0], int)
+ ):
+ return text
+ else:
+ raise ValueError(
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
+ )
+
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers."
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + # Throw an error if we can pad because there is no padding token + if pad_to_max_length and self.pad_token_id is None: + raise ValueError( + "Unable to set proper padding strategy as the tokenizer does not have a padding token. In this case please set the `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via the function add_special_tokens if you want to use a padding strategy" + ) + + first_ids = get_input_ids(text) + second_ids = get_input_ids(text_pair) if text_pair is not None else None + + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + max_length=max_length, + pad_to_max_length=pad_to_max_length, + add_special_tokens=add_special_tokens, + stride=stride, + truncation_strategy=truncation_strategy, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + ) + + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[str, List[str]], + add_special_tokens: bool = True, + max_length: Optional[int] = None, + stride: int = 0, + truncation_strategy: str = "longest_first", + pad_to_max_length: bool = False, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_masks: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_masks: bool = False, + return_offsets_mapping: bool = False, + return_input_lengths: bool = False, + **kwargs + ): + """ + Returns a dictionary containing the encoded sequence or sequence pair and additional information: + the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. + + Args: + batch_text_or_text_pairs (:obj:`List[str]` or :obj:`List[List[str]]`): + Batch of sequences or pair of sequences to be encoded. + This can be a list of string/string-sequences/int-sequences or a list of pair of + string/string-sequences/int-sequence (see details in encode_plus) + add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to ``True``, the sequences will be encoded with the special tokens relative + to their model. + max_length (:obj:`int`, `optional`, defaults to :obj:`None`): + If set to a number, will limit the total sequence returned so that it has a maximum length. + If there are overf lowing tokens, those will be added to the returned dictionary + stride (:obj:`int`, `optional`, defaults to ``0``): + If set to a number along with max_length, the overflowing tokens returned will contain some tokens + from the main sequence returned. The value of this argument defines the number of additional tokens. 
+ truncation_strategy (:obj:`str`, `optional`, defaults to `longest_first`): + String selected in the following options: + + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences) + - 'only_first': Only truncate the first sequence + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + pad_to_max_length (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set to True, the returned sequences will be padded according to the model's padding side and + padding index, up to their max length. If no max length is specified, the padding is done up to the + model's max length. The tokenizer padding sides are handled by the class attribute `padding_side` + which can be set to the following strings: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + Defaults to False: no padding. + return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`): + Can be set to 'tf' or 'pt' to return respectively TensorFlow :obj:`tf.constant` + or PyTorch :obj:`torch.Tensor` instead of a list of python integers. + return_token_type_ids (:obj:`bool`, `optional`, defaults to :obj:`None`): + Whether to return token type IDs. If left to the default, will return the token type IDs according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + return_attention_masks (:obj:`bool`, `optional`, defaults to :obj:`none`): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are attention masks? <../glossary.html#attention-mask>`__ + return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True to return overflowing token information (default False). + return_special_tokens_masks (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True to return special tokens mask information (default False). + return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True to return (char_start, char_end) for each token (default False). + If using Python's tokenizer, this method will raise NotImplementedError. This one is only available on + Rust-based tokenizers inheriting from PreTrainedTokenizerFast. 
+ return_input_lengths (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set the resulting dictionary will include the length of each sample + **kwargs: passed to the `self.tokenize()` method + + Return: + A Dictionary of shape:: + + { + input_ids: list[List[int]], + token_type_ids: list[List[int]] if return_token_type_ids is True (default) + attention_mask: list[List[int]] if return_attention_mask is True (default) + overflowing_tokens: list[List[int]] if a ``max_length`` is specified and return_ove rflowing_tokens is True + num_truncated_tokens: List[int] if a ``max_length`` is specified and return_overflowing_tokens is True + special_tokens_mask: list[List[int]] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True + } + + With the fields: + + - ``input_ids``: list of token ids to be fed to a model + - ``token_type_ids``: list of token type ids to be fed to a model + - ``attention_mask``: list of indices specifying which tokens should be attended to by the model + - ``overflowing_tokens``: list of overflowing tokens if a max length is specified. + - ``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified + - ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added + tokens and 1 specifying sequence tokens. + """ + + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize( + text, add_special_tokens=add_special_tokens, **kwargs + ) + return self.convert_tokens_to_ids(tokens) + elif ( + isinstance(text, (list, tuple)) + and len(text) > 0 + and isinstance(text[0], str) + ): + return self.convert_tokens_to_ids(text) + elif ( + isinstance(text, (list, tuple)) + and len(text) > 0 + and isinstance(text[0], int) + ): + return text + else: + raise ValueError( + "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." + ) + + # Throw an error if we can pad because there is no padding token + if pad_to_max_length and self.pad_token_id is None: + raise ValueError( + "Unable to set proper padding strategy as the tokenizer does not have a padding token. In this case please set the `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via the function add_special_tokens if you want to use a padding strategy" + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers." + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + input_ids = [] + for ids_or_pair_ids in batch_text_or_text_pairs: + if isinstance(ids_or_pair_ids, (list, tuple)) and len(ids_or_pair_ids) == 2: + ids, pair_ids = ids_or_pair_ids + else: + ids, pair_ids = ids_or_pair_ids, None + + first_ids = get_input_ids(ids) + second_ids = get_input_ids(pair_ids) if pair_ids is not None else None + input_ids.append((first_ids, second_ids)) + + if max_length is None and pad_to_max_length: + + def total_sequence_length(input_pairs): + first_ids, second_ids = input_pairs + return len(first_ids) + ( + self.num_added_tokens() + if second_ids is None + else (len(second_ids) + self.num_added_tokens(pair=True)) + ) + + max_length = max([total_sequence_length(ids) for ids in input_ids]) + + batch_outputs = {} + for first_ids, second_ids in input_ids: + # Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by + # the model. It adds special tokens, truncates sequences if overflowing while taking into account + # the special tokens an d manages a window stride for overflowing tokens + outputs = self.prepare_for_model( + first_ids, + pair_ids=second_ids, + max_length=max_length, + pad_to_max_length=pad_to_max_length, + add_special_tokens=add_special_tokens, + stride=stride, + truncation_strategy=truncation_strategy, + return_attention_mask=return_attention_masks, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_masks, + ) + + # Append the non-padded length to the output + if return_input_lengths: + outputs["input_len"] = len(outputs["input_ids"]) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + if return_tensors is not None: + + # Do the tensor conversion in batch + for key, value in batch_outputs.items(): + if return_tensors == "tf" and is_tf_available(): + try: + batch_outputs[key] = tf.constant(value) + except ValueError: + if None in [item for sequence in value for item in sequence]: + raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) + else: + raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) + elif return_tensors == "pt" and is_torch_available(): + try: + batch_outputs[key] = torch.tensor(value) + except ValueError: + raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) + except RuntimeError: + if None in [item for sequence in value for item in sequence]: + raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) + else: + raise + elif return_tensors is not None: + logger.warning( + "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( + return_tensors + ) + ) + + return batch_outputs + + def prepare_for_model( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + max_length: Optional[int] = None, + add_special_tokens: bool = True, + stride: int = 0, + truncation_strategy: str = "longest_first", + pad_to_max_length: bool = False, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + ): + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. 
+ It adds special tokens, truncates + sequences if overflowing while taking into account the special tokens and manages a window stride for + overflowing tokens + + Args: + ids: list of tokenized input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + max_length: maximum length of the returned list. Will truncate by taking into account the special tokens. + add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative + to their model. + stride: window s tride for overflowing tokens. Can be useful for edge effect removal when using sequential + list of inputs. + truncation_strategy: string selected in the following options: + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences) + - 'only_first': Only truncate the first sequence + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and + padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length. + The tokenizer padding sides are handled by the following strings: + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + Defaults to False: no padding. + return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant + or PyTorch torch.Tensor instead of a list of python integers. + return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True). + return_attention_mask: (optional) Set to False to avoid returning attention mask (default True) + return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False). + return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False). + + Return: + A Dictionary of shape:: + + { + input_ids: list[int], + token_type_ids: list[int] if return_token_type_ids is True (default) + overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True + num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True + special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True + } + + With the fields: + ``input_ids``: list of token ids to be fed to a model + ``token_type_ids``: list of token type ids to be fed to a model + + ``overflowing_tokens``: list of overflowing tokens if a max length is specified. + ``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified + ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added + tokens and 1 specifying sequence tokens. 
+ """ + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Handle max sequence length + total_len = ( + len_ids + + len_pair_ids + + (self.num_added_tokens(pair=pair) if add_special_tokens else 0) + ) + if max_length and total_len > max_length: + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Handle special_tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) + + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask( + ids, pair_ids + ) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + + if max_length and len(encoded_inputs["input_ids"]) > max_length: + encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length] + if return_token_type_ids: + encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][ + :max_length + ] + if return_special_tokens_mask: + encoded_inputs["special_tokens_mask"] = encoded_inputs[ + "special_tokens_mask" + ][:max_length] + + if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). Running this sequence through the model will result in " + "indexing errors".format(len(ids), self.max_len) + ) + + needs_to_be_padded = pad_to_max_length and ( + max_length + and len(encoded_inputs["input_ids"]) < max_length + or max_length is None + and len(encoded_inputs["input_ids"]) < self.max_len + and self.max_len <= 10000 + ) + + if pad_to_max_length and max_length is None and self.max_len > 10000: + logger.warning( + "Sequence can't be padded as no maximum length is specified and the model maximum length is too high." 
+ ) + + if needs_to_be_padded: + difference = (max_length if max_length is not None else self.max_len) - len( + encoded_inputs["input_ids"] + ) + + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len( + encoded_inputs["input_ids"] + ) + [0] * difference + if return_token_type_ids: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + + [self.pad_token_type_id] * difference + ) + if return_special_tokens_mask: + encoded_inputs["special_tokens_mask"] = ( + encoded_inputs["special_tokens_mask"] + [1] * difference + ) + encoded_inputs["input_ids"] = ( + encoded_inputs["input_ids"] + [self.pad_token_id] * difference + ) + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + [1] * len( + encoded_inputs["input_ids"] + ) + if return_token_type_ids: + encoded_inputs["token_type_ids"] = [ + self.pad_token_type_id + ] * difference + encoded_inputs["token_type_ids"] + if return_special_tokens_mask: + encoded_inputs["special_tokens_mask"] = [ + 1 + ] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs["input_ids"] = [ + self.pad_token_id + ] * difference + encoded_inputs["input_ids"] + + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + elif return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + + # Prepare inputs as tensors if asked + if return_tensors == "tf" and is_tf_available(): + encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]]) + + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = tf.constant( + [encoded_inputs["token_type_ids"]] + ) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = tf.constant( + [encoded_inputs["attention_mask"]] + ) + + elif return_tensors == "pt" and is_torch_available(): + encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]]) + + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = torch.tensor( + [encoded_inputs["token_type_ids"]] + ) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = torch.tensor( + [encoded_inputs["attention_mask"]] + ) + elif return_tensors is not None: + logger.warning( + "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( + return_tensors + ) + ) + + return encoded_inputs + + def prepare_for_tokenization(self, text, **kwargs): + """Performs any necessary transformations before tokenization""" + return text + + def truncate_sequences( + self, + ids, + pair_ids=None, + num_tokens_to_remove=0, + truncation_strategy="longest_first", + stride=0, + ): + """Truncates a sequence pair in place to the maximum length. + truncation_strategy: string selected in the following options: + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences). + Overflowing tokens only contains overflow from the first sequence. + - 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove. 
+ - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + """ + if num_tokens_to_remove <= 0: + return ids, pair_ids, [] + + if truncation_strategy == "longest_first": + overflowing_tokens = [] + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + overflowing_tokens = [ids[-1]] + overflowing_tokens + ids = ids[:-1] + else: + pair_ids = pair_ids[:-1] + window_len = min(len(ids), stride) + if window_len > 0: + overflowing_tokens = ids[-window_len:] + overflowing_tokens + elif truncation_strategy == "only_first": + assert len(ids) > num_tokens_to_remove + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] + ids = ids[:-num_tokens_to_remove] + elif truncation_strategy == "only_second": + assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + elif truncation_strategy == "do_not_truncate": + raise ValueError( + "Input sequence are too long for max_length. Please select a truncation strategy." + ) + else: + raise ValueError( + "Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']" + ) + return (ids, pair_ids, overflowing_tokens) + + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): + if token_ids_1 is None: + return len(token_ids_0) * [0] + return [0] * len(token_ids_0) + [1] * len(token_ids_1) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A RoBERTa sequence has the following format: + single sequence: X + pair of sequences: A B + """ + if token_ids_1 is None: + return token_ids_0 + return token_ids_0 + token_ids_1 + + def get_special_tokens_mask( + self, token_ids_0, token_ids_1=None, already_has_special_tokens=False + ): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. + + Args: + token_ids_0: list of ids (must not contain special tokens) + token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formated with + special tokens for the model + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a single index or a sequence of indices (integers) in a token " + (resp.) a sequence of tokens (str), using the vocabulary and added tokens. + + Args: + skip_special_tokens: Don't decode special tokens (self.all_special_tokens). 
Default: False + """ + if isinstance(ids, int): + if ids in self.added_tokens_decoder: + return self.added_tokens_decoder[ids] + else: + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + if index in self.added_tokens_decoder: + tokens.append(self.added_tokens_decoder[index]) + else: + tokens.append(self._convert_id_to_token(index)) + return tokens + + def _convert_id_to_token(self, index): + raise NotImplementedError + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string. + The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) + but we often want to remove sub-word tokenization artifacts at the same time. + """ + return " ".join(self.convert_ids_to_tokens(tokens)) + + def decode( + self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True + ): + """ + Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary + with options to remove special tokens and clean up tokenization spaces. + Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token _ids))``. + + Args: + token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods. + skip_special_tokens: if set to True, will replace special tokens. + clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. + """ + filtered_tokens = self.convert_ids_to_tokens( + token_ids, skip_special_tokens=skip_special_tokens + ) + + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separatly for added tokens and byte-level tokens + # cf. https://github.com/huggingface/transformers/issues/1133 + sub_texts = [] + current_sub_text = [] + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_ids: + continue + if token in self.added_tokens_encoder: + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + text = " ".join(sub_texts) + + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + @property + def special_tokens_map(self): + """A dictionary mapping special token class attribute (cls_token, unk_token...) to their + values ('', ''...) + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = getattr(self, "_" + attr) + if attr_value: + set_attr[attr] = attr_value + return set_attr + + @property + def all_special_tokens(self): + """List all the special tokens ('', ''...) mapped to class attributes + (cls_token, unk_token...). + """ + all_toks = [] + set_attr = self.special_tokens_map + for attr_value in set_attr.values(): + all_toks = all_toks + ( + list(attr_value) + if isinstance(attr_value, (list, tuple)) + else [attr_value] + ) + all_toks = list(set(all_toks)) + return all_toks + + @property + def all_special_ids(self): + """List the vocabulary indices of the special tokens ('', ''...) mapped to + class attributes (cls_token, unk_token...). 
+ """ + all_toks = self.all_special_tokens + all_ids = self.convert_tokens_to_ids(all_toks) + return all_ids + + @staticmethod + def clean_up_tokenization(out_string): + """Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.""" + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" do not", " don't") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + +class PreTrainedTokenizerFast(PreTrainedTokenizer): + + model_input_names = ["token_type_ids", "attention_mask"] + + def __init__(self, tokenizer: BaseTokenizer, **kwargs): + if tokenizer is None: + raise ValueError("Provided tokenizer cannot be None") + self._tokenizer = tokenizer + + super().__init__(**kwargs) + self.max_len_single_sentence = self.max_len - self.num_added_tokens( + False + ) # take into account special tokens + self.max_len_sentences_pair = self.max_len - self.num_added_tokens( + True + ) # t ake into account special tokens + + @property + def tokenizer(self): + return self._tokenizer + + @property + def decoder(self): + return self._tokenizer._tokenizer.decoder + + @property + def vocab_size(self): + return self._tokenizer.get_vocab_size(with_added_tokens=False) + + def __len__(self): + return self._tokenizer.get_vocab_size(with_added_tokens=True) + + @PreTrainedTokenizer.bos_token.setter + def bos_token(self, value): + self._bos_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.eos_token.setter + def eos_token(self, value): + self._eos_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.unk_token.setter + def unk_token(self, value): + self._unk_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.sep_token.setter + def sep_token(self, value): + self._sep_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.pad_token.setter + def pad_token(self, value): + self._pad_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.cls_token.setter + def cls_token(self, value): + self._cls_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.mask_token.setter + def mask_token(self, value): + self._mask_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + self._update_special_tokens() + + def _update_special_tokens(self): + if self._tokenizer is not None: + self._tokenizer.add_special_tokens(self.all_special_tokens) + + def _convert_encoding( + self, + encoding, + return_tensors=None, + return_token_type_ids=None, + return_attention_mask=None, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + return_offsets_mapping=False, + ): + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_overflowing_tokens and encoding.overflowing is not None: + encodings = [encoding] + encoding.overflowing + else: + encodings = [encoding] + + encoding_dict = defaultdict(list) + for e in encodings: + encoding_dict["input_ids"].append(e.ids) + + if return_token_type_ids: + encoding_dict["token_type_ids"].append(e.type_ids) + if return_attention_mask: + 
encoding_dict["attention_mask"].append(e.attention_mask) + if return_special_tokens_mask: + encoding_dict["special_tokens_mask"].append(e.special_tokens_mask) + if return_offsets_mapping: + encoding_dict["offset_mapping"].append( + [e.original_str.offsets(o) for o in e.offsets] + ) + + # Prepare inputs as tensors if asked + if return_tensors == "tf" and is_tf_available(): + encoding_dict["input_ids"] = tf.constant(encoding_dict["input_ids"]) + if "token_type_ids" in encoding_dict: + encoding_dict["token_type_ids"] = tf.constant( + encoding_dict["token_type_ids"] + ) + + if "attention_mask" in encoding_dict: + encoding_dict["attention_mask"] = tf.constant( + encoding_dict["attention_mask"] + ) + + elif return_tensors == "pt" and is_torch_available(): + encoding_dict["input_ids"] = torch.tensor(encoding_dict["input_ids"]) + if "token_type_ids" in encoding_dict: + encoding_dict["token_type_ids"] = torch.tensor( + encoding_dict["token_type_ids"] + ) + + if "attention_mask" in encoding_dict: + encoding_dict["attention_mask"] = torch.tensor( + encoding_dict["attention_mask"] + ) + elif return_tensors is not None: + logger.warning( + "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( + return_tensors + ) + ) + + return encoding_dict + + def _convert_token_to_id_with_added_voc(self, token): + id = self._tokenizer.token_to_id(token) + if id is None: + return self.unk_token_id + return id + + def _convert_id_to_token(self, index): + return self._tokenizer.id_to_token(int(index)) + + def convert_tokens_to_string(self, tokens): + return self._tokenizer.decode(tokens) + + def add_tokens(self, new_tokens): + if isinstance(new_tokens, str): + new_tokens = [new_tokens] + return self._tokenizer.add_tokens(new_tokens) + + def add_special_tokens(self, special_tokens_dict): + added = super().add_special_tokens(special_tokens_dict) + self._update_special_tokens() + return added + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if token_ids_1 is None: + return token_ids_0 + else: + return token_ids_0 + token_ids_1 + + def num_added_tokens(self, pair=False): + return self.tokenizer.num_special_tokens_to_add(pair) + + def tokenize(self, text, **kwargs): + return self.tokenizer.encode(text).tokens + + def batch_encode_plus( + self, + batch_text_or_text_pairs: Optional[Union[List[str], List[Tuple[str]]]] = None, + add_special_tokens: bool = True, + max_length: Optional[int] = None, + stride: int = 0, + truncation_strategy: str = "longest_first", + pad_to_max_length: bool = False, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + **kwargs + ): + if not add_special_tokens: + logger.warning( + "Fast tokenizers add special tokens by default. To remove special tokens, please specify" + "`add_special_tokens=False` during the initialisation rather than when calling `encode`," + "`encode_plus` or `batch_encode_plus`." 
+ ) + + # Needed if we have to return a tensor + pad_to_max_length = pad_to_max_length or (return_tensors is not None) + + # Throw an error if we can pad because there is no padding token + if pad_to_max_length and self.pad_token_id is None: + raise ValueError( + "Unable to set proper padding strategy as the tokenizer does not have a padding token" + ) + + # Set the truncation and padding strategy and restore the initial configuration + with truncate_and_pad( + tokenizer=self._tokenizer, + max_length=max_length, + stride=stride, + strategy=truncation_strategy, + pad_to_max_length=pad_to_max_length, + padding_side=self.padding_side, + pad_token_id=self.pad_token_id, + pad_token_type_id=self.pad_token_type_id, + pad_token=self._pad_token, + ): + + if not isinstance(batch_text_or_text_pairs, list): + raise TypeError( + "batch_text_or_text_pairs has to be a list (got {})".format( + type(batch_text_or_text_pairs) + ) + ) + + # Avoid thread overhead if only one example. + if len(batch_text_or_text_pairs) == 1: + if isinstance(batch_text_or_text_pairs[0], (tuple, list)): + tokens = self._tokenizer.encode(*batch_text_or_text_pairs[0]) + else: + tokens = self._tokenizer.encode(batch_text_or_text_pairs[0]) + tokens = [tokens] + else: + tokens = self._tokenizer.encode_batch(batch_text_or_text_pairs) + + # Convert encoding to dict + tokens = [ + self._convert_encoding( + encoding=encoding, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + ) + for encoding in tokens + ] + + # Sanitize the output to have dict[list] from list[dict] + sanitized = {} + for key in tokens[0].keys(): + stack = [e for item in tokens for e in item[key]] + if return_tensors == "tf": + stack = tf.stack(stack, axis=0) + elif return_tensors == "pt": + stack = torch.stack(stack, dim=0) + elif not return_tensors and len(stack) == 1: + stack = stack[0] + + sanitized[key] = stack + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [ + i if len(item["input_ids"]) == 1 else [i] * len(item["input_ids"]) + for i, item in enumerate(tokens) + ] + sanitized["overflow_to_sample_mapping"] = overflow_to_sample_mapping + return sanitized + + def encode_plus( + self, + text: str, + text_pair: Optional[str] = None, + add_special_tokens: bool = False, + max_length: Optional[int] = None, + pad_to_max_length: bool = False, + stride: int = 0, + truncation_strategy: str = "longest_first", + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + **kwargs + ): + batched_input = [(text, text_pair)] if text_pair else [text] + batched_output = self.batch_encode_plus( + batched_input, + add_special_tokens=add_special_tokens, + max_length=max_length, + stride=stride, + truncation_strategy=truncation_strategy, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + 
pad_to_max_length=pad_to_max_length, + **kwargs, + ) + + # Return tensor is None, then we can remove the leading batch axis + if not return_tensors: + return { + key: value[0] if isinstance(value[0], list) else value + for key, value in batched_output.items() + } + else: + return batched_output + + def decode( + self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True + ): + text = self.tokenizer.decode(token_ids, skip_special_tokens) + + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def save_vocabulary(self, save_directory): + if os.path.isdir(save_directory): + files = self._tokenizer.save(save_directory) + else: + folder, file = os.path.split(os.path.abspath(save_directory)) + files = self._tokenizer.save(folder, name=file) + + return tuple(files) + + +def trim_batch( + input_ids, + pad_token_id, + attention_mask=None, +): + """Remove columns that are populated exclusively by pad_token_id""" + keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) + if attention_mask is None: + return input_ids[:, keep_column_mask] + else: + return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) diff --git a/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/README.txt b/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/README.txt new file mode 100644 index 0000000000000000000000000000000000000000..b7160cd67b2514b207fad1d259c8cf10276902de --- /dev/null +++ b/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/README.txt @@ -0,0 +1,295 @@ +A Brief Introduction of the ROUGE Summary Evaluation Package +by Chin-Yew LIN +Univeristy of Southern California/Information Sciences Institute +05/26/2005 + +<> + +(1) Correct the resampling routine which ignores the last evaluation + item in the evaluation list. Therefore, the average scores reported + by ROUGE is only based on the first N-1 evaluation items. + Thanks Barry Schiffman at Columbia University to report this bug. + This bug only affects ROUGE-1.5.X. For pre-1.5 ROUGE, it only affects + the computation of confidence interval (CI) estimation, i.e. CI is only + estimated by the first N-1 evaluation items, but it *does not* affect + average scores. +(2) Correct stemming on multi-token BE heads and modifiers. + Previously, only single token heads and modifiers were assumed. +(3) Change read_text and read_text_LCS functions to read exact words or + bytes required by users. Previous versions carry out whitespace + compression and other string clear up actions before enforce the length + limit. +(4) Add the capability to score summaries in Basic Element (BE) + format by using option "-3", standing for BE triple. There are 6 + different modes in BE scoring. We suggest using *"-3 HMR"* on BEs + extracted from Minipar parse trees based on our correlation analysis + of BE-based scoring vs. human judgements on DUC 2002 & 2003 automatic + summaries. +(5) ROUGE now generates three scores (recall, precision and F-measure) + for each evaluation. Previously, only one score is generated + (recall). Precision and F-measure scores are useful when the target + summary length is not enforced. Only recall scores were necessary since + DUC guideline dictated the limit on summary length. For comparison to + previous DUC results, please use the recall scores. The default alpha + weighting for computing F-measure is 0.5. Users can specify a + particular alpha weighting that fits their application scenario using + option "-p alpha-weight". 
Where *alpha-weight* is a number between 0 + and 1 inclusively. +(6) Pre-1.5 version of ROUGE used model average to compute the overall + ROUGE scores when there are multiple references. Starting from v1.5+, + ROUGE provides an option to use the best matching score among the + references as the final score. The model average option is specified + using "-f A" (for Average) and the best model option is specified + using "-f B" (for the Best). The "-f A" option is better when use + ROUGE in summarization evaluations; while "-f B" option is better when + use ROUGE in machine translation (MT) and definition + question-answering (DQA) evaluations since in a typical MT or DQA + evaluation scenario matching a single reference translation or + definition answer is sufficient. However, it is very likely that + multiple different but equally good summaries exist in summarization + evaluation. +(7) ROUGE v1.5+ also provides the option to specify whether model unit + level average will be used (macro-average, i.e. treating every model + unit equally) or token level average will be used (micro-average, + i.e. treating every token equally). In summarization evaluation, we + suggest using model unit level average and this is the default setting + in ROUGE. To specify other average mode, use "-t 0" (default) for + model unit level average, "-t 1" for token level average and "-t 2" + for output raw token counts in models, peers, and matches. +(8) ROUGE now offers the option to use file list as the configuration + file. The input format of the summary files are specified using the + "-z INPUT-FORMAT" option. The INPUT-FORMAT can be SEE, SPL, ISI or + SIMPLE. When "-z" is specified, ROUGE assumed that the ROUGE + evaluation configuration file is a file list with each evaluation + instance per line in the following format: + +peer_path1 model_path1 model_path2 ... model_pathN +peer_path2 model_path1 model_path2 ... model_pathN +... +peer_pathM model_path1 model_path2 ... model_pathN + + The first file path is the peer summary (system summary) and it + follows with a list of model summaries (reference summaries) separated + by white spaces (spaces or tabs). +(9) When stemming is applied, a new WordNet exception database based + on WordNet 2.0 is used. The new database is included in the data + directory. + +<> + +(1) Use "-h" option to see a list of options. + Summary: +Usage: ROUGE-1.5.4.pl + [-a (evaluate all systems)] + [-c cf] + [-d (print per evaluation scores)] + [-e ROUGE_EVAL_HOME] + [-h (usage)] + [-b n-bytes|-l n-words] + [-m (use Porter stemmer)] + [-n max-ngram] + [-s (remove stopwords)] + [-r number-of-samples (for resampling)] + [-2 max-gap-length (if < 0 then no gap length limit)] + [-3 ] + [-u (include unigram in skip-bigram) default no)] + [-U (same as -u but also compute regular skip-bigram)] + [-w weight (weighting factor for WLCS)] + [-v (verbose)] + [-x (do not calculate ROUGE-L)] + [-f A|B (scoring formula)] + [-p alpha (0 <= alpha <=1)] + [-t 0|1|2 (count by token instead of sentence)] + [-z ] + [] + + ROUGE-eval-config-file: Specify the evaluation setup. Three files come with the ROUGE + evaluation package, i.e. ROUGE-test.xml, verify.xml, and verify-spl.xml are + good examples. + + systemID: Specify which system in the ROUGE-eval-config-file to perform the evaluation. + If '-a' option is used, then all systems are evaluated and users do not need to + provide this argument. 
+ + Default: + When running ROUGE without supplying any options (except -a), the following defaults are used: + (1) ROUGE-L is computed; + (2) 95% confidence interval; + (3) No stemming; + (4) Stopwords are inlcuded in the calculations; + (5) ROUGE looks for its data directory first through the ROUGE_EVAL_HOME environment variable. If + it is not set, the current directory is used. + (6) Use model average scoring formula. + (7) Assign equal importance of ROUGE recall and precision in computing ROUGE f-measure, i.e. alpha=0.5. + (8) Compute average ROUGE by averaging sentence (unit) ROUGE scores. + Options: + -2: Compute skip bigram (ROGUE-S) co-occurrence, also specify the maximum gap length between two words (skip-bigram) + -u: Compute skip bigram as -2 but include unigram, i.e. treat unigram as "start-sentence-symbol unigram"; -2 has to be specified. + -3: Compute BE score. + H -> head only scoring (does not applied to Minipar-based BEs). + HM -> head and modifier pair scoring. + HMR -> head, modifier and relation triple scoring. + HM1 -> H and HM scoring (same as HM for Minipar-based BEs). + HMR1 -> HM and HMR scoring (same as HMR for Minipar-based BEs). + HMR2 -> H, HM and HMR scoring (same as HMR for Minipar-based BEs). + -a: Evaluate all systems specified in the ROUGE-eval-config-file. + -c: Specify CF\% (0 <= CF <= 100) confidence interval to compute. The default is 95\% (i.e. CF=95). + -d: Print per evaluation average score for each system. + -e: Specify ROUGE_EVAL_HOME directory where the ROUGE data files can be found. + This will overwrite the ROUGE_EVAL_HOME specified in the environment variable. + -f: Select scoring formula: 'A' => model average; 'B' => best model + -h: Print usage information. + -b: Only use the first n bytes in the system/peer summary for the evaluation. + -l: Only use the first n words in the system/peer summary for the evaluation. + -m: Stem both model and system summaries using Porter stemmer before computing various statistics. + -n: Compute ROUGE-N up to max-ngram length will be computed. + -p: Relative importance of recall and precision ROUGE scores. Alpha -> 1 favors precision, Alpha -> 0 favors recall. + -s: Remove stopwords in model and system summaries before computing various statistics. + -t: Compute average ROUGE by averaging over the whole test corpus instead of sentences (units). + 0: use sentence as counting unit, 1: use token as couting unit, 2: same as 1 but output raw counts + instead of precision, recall, and f-measure scores. 2 is useful when computation of the final, + precision, recall, and f-measure scores will be conducted later. + -r: Specify the number of sampling point in bootstrap resampling (default is 1000). + Smaller number will speed up the evaluation but less reliable confidence interval. + -w: Compute ROUGE-W that gives consecutive matches of length L in an LCS a weight of 'L^weight' instead of just 'L' as in LCS. + Typically this is set to 1.2 or other number greater than 1. + -v: Print debugging information for diagnositic purpose. + -x: Do not calculate ROUGE-L. + -z: ROUGE-eval-config-file is a list of peer-model pair per line in the specified format (SEE|SPL|ISI|SIMPLE). + +(2) Please read RELEASE-NOTE.txt for information about updates from previous versions. + +(3) The following files coming with this package in the "sample-output" + directory are the expected output of the evaluation files in the + "sample-test" directory. 
+ (a) use "data" as ROUGE_EVAL_HOME, compute 95% confidence interval, + compute ROUGE-L (longest common subsequence, default), + compute ROUGE-S* (skip bigram) without gap length limit, + compute also ROUGE-SU* (skip bigram with unigram), + run resampling 1000 times, + compute ROUGE-N (N=1 to 4), + compute ROUGE-W (weight = 1.2), and + compute these ROUGE scores for all systems: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-a.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -a ROUGE-test.xml + + (b) Same as (a) but apply Porter's stemmer on the input: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-a-m.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -m -a ROUGE-test.xml + + (c) Same as (b) but apply also a stopword list on the input: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-a-m-s.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -m -s -a ROUGE-test.xml + + (d) Same as (a) but apply a summary length limit of 10 words: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-l10-a.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -l 10 -a ROUGE-test.xml + + (e) Same as (d) but apply Porter's stemmer on the input: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-l10-a-m.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -l 10 -m -a ROUGE-test.xml + + (f) Same as (e) but apply also a stopword list on the input: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-l10-a-m-s.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -l 10 -m -s -a ROUGE-test.xml + + (g) Same as (a) but apply a summary lenght limit of 75 bytes: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-b75-a.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U - r 1000 -n 4 -w 1.2 -b 75 -a ROUGE-test.xml + + (h) Same as (g) but apply Porter's stemmer on the input: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-b75-a-m.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -b 75 -m -a ROUGE-test.xml + + (i) Same as (h) but apply also a stopword list on the input: + ROUGE-test-c95-2-1-U-r1000-n4-w1.2-b75-a-m-s.out + > ROUGE-1.5.4.pl -e data -c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -b 75 -m -s -a ROUGE-test.xml + + Sample DUC2002 data (1 system and 1 model only per DUC 2002 topic), their BE and + ROUGE evaluation configuration file in XML and file list format, + and their expected output are also included for your reference. + + (a) Use DUC2002-BE-F.in.26.lst, a BE files list, as ROUGE the + configuration file: + command> ROUGE-1.5.4.pl -3 HM -z SIMPLE DUC2002-BE-F.in.26.lst 26 + output: DUC2002-BE-F.in.26.lst.out + (b) Use DUC2002-BE-F.in.26.simple.xml as ROUGE XML evaluation configuration file: + command> ROUGE-1.5.4.pl -3 HM DUC2002-BE-F.in.26.simple.xml 26 + output: DUC2002-BE-F.in.26.simple.out + (c) Use DUC2002-BE-L.in.26.lst, a BE files list, as ROUGE the + configuration file: + command> ROUGE-1.5.4.pl -3 HM -z SIMPLE DUC2002-BE-L.in.26.lst 26 + output: DUC2002-BE-L.in.26.lst.out + (d) Use DUC2002-BE-L.in.26.simple.xml as ROUGE XML evaluation configuration file: + command> ROUGE-1.5.4.pl -3 HM DUC2002-BE-L.in.26.simple.xml 26 + output: DUC2002-BE-L.in.26.simple.out + (e) Use DUC2002-ROUGE.in.26.spl.lst, a BE files list, as ROUGE the + configuration file: + command> ROUGE-1.5.4.pl -n 4 -z SPL DUC2002-ROUGE.in.26.spl.lst 26 + output: DUC2002-ROUGE.in.26.spl.lst.out + (f) Use DUC2002-ROUGE.in.26.spl.xml as ROUGE XML evaluation configuration file: + command> ROUGE-1.5.4.pl -n 4 DUC2002-ROUGE.in.26.spl.xml 26 + output: DUC2002-ROUGE.in.26.spl.out + +<> + +(1) You need to have DB_File installed. 
If the Perl script complains + about database version incompatibility, you can create a new + WordNet-2.0.exc.db by running the buildExceptionDB.pl script in + the "data/WordNet-2.0-Exceptions" subdirectory. +(2) You also need to install XML::DOM from http://www.cpan.org. + Direct link: http://www.cpan.org/modules/by-module/XML/XML-DOM-1.43.tar.gz. + You might need install extra Perl modules that are required by + XML::DOM. +(3) Setup an environment variable ROUGE_EVAL_HOME that points to the + "data" subdirectory. For example, if your "data" subdirectory + located at "/usr/local/ROUGE-1.5.4/data" then you can setup + the ROUGE_EVAL_HOME as follows: + (a) Using csh or tcsh: + $command_prompt>setenv ROUGE_EVAL_HOME /usr/local/ROUGE-1.5.4/data + (b) Using bash + $command_prompt>ROUGE_EVAL_HOME=/usr/local/ROUGE-1.5.4/data + $command_prompt>export ROUGE_EVAL_HOME +(4) Run ROUGE-1.5.4.pl without supplying any arguments will give + you a description of how to use the ROUGE script. +(5) Please look into the included ROUGE-test.xml, verify.xml. and + verify-spl.xml evaluation configuration files for preparing your + own evaluation setup. More detailed description will be provided + later. ROUGE-test.xml and verify.xml specify the input from + systems and references are in SEE (Summary Evaluation Environment) + format (http://www.isi.edu/~cyl/SEE); while verify-spl.xml specify + inputs are in sentence per line format. + +<> + +(1) Please look into the "docs" directory for more information about + ROUGE. +(2) ROUGE-Note-v1.4.2.pdf explains how ROUGE works. It was published in + Proceedings of the Workshop on Text Summarization Branches Out + (WAS 2004), Bacelona, Spain, 2004. +(3) NAACL2003.pdf presents the initial idea of applying n-gram + co-occurrence statistics in automatic evaluation of + summarization. It was publised in Proceedsings of 2003 Language + Technology Conference (HLT-NAACL 2003), Edmonton, Canada, 2003. +(4) NTCIR2004.pdf discusses the effect of sa mple size on the + reliability of automatic evaluation results using data in the past + Document Understanding Conference (DUC) as examples. It was + published in Proceedings of the 4th NTCIR Meeting, Tokyo, Japan, 2004. +(5) ACL2004.pdf shows how ROUGE can be applied on automatic evaluation + of machine translation. It was published in Proceedings of the 42nd + Annual Meeting of the Association for Computational Linguistics + (ACL 2004), Barcelona, Spain, 2004. +(6) COLING2004.pdf proposes a new meta-evaluation framework, ORANGE, for + automatic evaluation of automatic evaluation methods. We showed + that ROUGE-S and ROUGE-L were significantly better than BLEU, + NIST, WER, and PER automatic MT evalaution methods under the + ORANGE framework. It was published in Proceedings of the 20th + International Conference on Computational Linguistics (COLING 2004), + Geneva, Switzerland, 2004. +(7) For information about BE, please go to http://www.isi.edu/~cyl/BE. + +<> + + Thanks for using the ROUGE evaluation package. If you have any +questions or comments, please send them to cyl@isi.edu. I will do my +best to answer your questions. 
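As a concrete illustration of the installation and usage notes above, the following minimal Python sketch drives the vendored ROUGE-1.5.5.pl via subprocess, using the flag set from example (a) and a ROUGE_EVAL_HOME pointing at the bundled data directory. The `run_rouge` helper name and the paths are hypothetical; this is only a sketch of how the documented command line can be invoked, not the wrapper this repository actually uses.

```python
# Minimal sketch: calling the vendored ROUGE-1.5.5.pl from Python.
# Assumptions (not taken from this repository's own wrapper code):
#   - `rouge_home` points at the ROUGE-1.5.5 directory added in this diff,
#   - a valid ROUGE XML config file already exists (e.g. ROUGE-test.xml),
#   - Perl plus the XML::DOM and DB_File modules are installed.
import os
import subprocess


def run_rouge(rouge_home: str, config_xml: str) -> str:
    """Run ROUGE-1.5.5.pl with the flags from example (a) in the README."""
    env = dict(os.environ, ROUGE_EVAL_HOME=os.path.join(rouge_home, "data"))
    cmd = [
        "perl",
        os.path.join(rouge_home, "ROUGE-1.5.5.pl"),
        "-e", env["ROUGE_EVAL_HOME"],
        "-c", "95",        # 95% confidence interval
        "-2", "-1", "-U",  # skip-bigram with no gap limit, plus unigram
        "-r", "1000",      # 1000 bootstrap resampling points
        "-n", "4",         # ROUGE-1 .. ROUGE-4
        "-w", "1.2",       # ROUGE-W weight
        "-a",              # evaluate all systems in the config file
        config_xml,
    ]
    result = subprocess.run(cmd, env=env, capture_output=True, text=True, check=True)
    return result.stdout  # the script prints its per-system scores to stdout


# Example call (hypothetical paths):
# print(run_rouge("model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5",
#                 "ROUGE-test.xml"))
```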
diff --git a/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/RELEASE-NOTE.txt b/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/RELEASE-NOTE.txt new file mode 100644 index 0000000000000000000000000000000000000000..39547b9578e58fd99943b52150b398de158d4c11 --- /dev/null +++ b/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/RELEASE-NOTE.txt @@ -0,0 +1,232 @@ +# Revision Note: 05/26/2005, Chin-Yew LIN +# 1.5.5 +# (1) Correct stemming on multi-token BE heads and modifiers. +# Previously, only single token heads and modifiers were assumed. +# (2) Correct the resampling routine which ignores the last evaluation +# item in the evaluation list. Therefore, the average scores reported +# by ROUGE is only based on the first N-1 evaluation items. +# Thanks Barry Schiffman at Columbia University to report this bug. +# This bug only affects ROUGE-1.5.X. For pre-1.5 ROUGE, it only affects +# the computation of confidence interval (CI) estimation, i.e. CI is only +# estimated by the first N-1 evaluation items, but it *does not* affect +# average scores. +# (3) Change read_text and read_text_LCS functions to read exact words or +# bytes required by users. Previous versions carry out whitespace +# compression and other string clear up actions before enforce the length +# limit. +# 1.5.4.1 +# (1) Minor description change about "-t 0" option. +# 1.5.4 +# (1) Add easy evalution mode for single reference evaluations with -z +# option. +# 1.5.3 +# (1) Add option to compute ROUGE score based on SIMPLE BE format. Given +# a set of peer and model summary file in BE format with appropriate +# options, ROUGE will compute matching scores based on BE lexical +# matches. +# There are 6 options: +# 1. H : Head only match. This is similar to unigram match but +# only BE Head is used in matching. BEs generated by +# Minipar-based breaker do not include head-only BEs, +# therefore, the score will always be zero. Use HM or HMR +# optiions instead. +# 2. HM : Head and modifier match. This is similar to bigram or +# skip bigram but it's head-modifier bigram match based on +# parse result. Only BE triples with non-NIL modifier are +# included in the matching. +# 3. HMR : Head, modifier, and relation match. This is similar to +# trigram match but it's head-modifier-relation trigram +# match based on parse result. Only BE triples with non-NIL +# relation are included in the matching. +# 4. HM1 : This is combination of H and HM. It is similar to unigram + +# bigram or skip bigram with unigram match but it's +# head-modifier bigram match based on parse result. +# In this case, the modifier field in a BE can be "NIL" +# 5. HMR1 : This is combination of HM and HMR. It is similar to +# trigram match but it's head-modifier-relation trigram +# match based on parse result. In this case, the relation +# field of the BE can be "NIL". +# 6. HMR2 : This is combination of H, HM and HMR. It is similar to +# trigram match but it's head-modifier-relation trigram +# match based on parse result. In this case, the modifier and +# relation fields of the BE can both be "NIL". +# 1.5.2 +# (1) Add option to compute ROUGE score by token using the whole corpus +# as average unit instead of individual sentences. Previous versions of +# ROUGE uses sentence (or unit) boundary to break counting unit and takes +# the average score from the counting unit as the final score. 
+# Using the whole corpus as one single counting unit can potentially +# improve the reliablity of the final score that treats each token as +# equally important; while the previous approach considers each sentence as +# equally important that ignores the length effect of each individual +# sentences (i.e. long sentences contribute equal weight to the final +# score as short sentences.) +# +v1.2 provide a choice of these two counting modes that users can +# choose the one that fits their scenarios. +# 1.5.1 +# (1) Add precision oriented measure and f-measure to deal with different lengths +# in candidates and references. Importance between recall and precision can +# be controled by 'alpha' parameter: +# alpha -> 0: recall is more important +# alpha -> 1: precision is more important +# Following Chapter 7 in C.J. van Rijsbergen's "Information Retrieval". +# http://www.dcs.gla.ac.uk/Keith/Chapter.7/Ch.7.html +# F = 1/(alpha * (1/P) + (1 - alpha) * (1/R)) ;;; weighted harmonic mean +# 1.4.2 +# (1) Enforce length limit at the time when summary text is read. Previously (before +# and including v1.4.1), length limit was enforced at tokenization time. +# 1.4.1 +# (1) Fix potential over counting in ROUGE-L and ROUGE-W +# In previous version (i.e. 1.4 and order), LCS hit is computed +# by summing union hit over all model sentences. Each model sentence +# is compared with all peer sentences and mark the union LCS. The +# length of the union LCS is the hit of that model sentence. The +# final hit is then sum over all model union LCS hits. This potentially +# would over count a peer sentence which already been marked as contributed +# to some other model sentence. Therefore, double counting is resulted. +# This is seen in evalution where ROUGE-L score is higher than ROUGE-1 and +# this is not correct. +# ROUGEeval-1.4.1.pl fixes this by add a clip function to prevent +# double counting. +# 1.4 +# (1) Remove internal Jackknifing proce dure: +# Now the ROUGE script will use all the references listed in the +# section in each section and no +# automatic Jackknifing is performed. +# If Jackknifing procedure is required when comparing human and system +# performance, then users have to setup the procedure in the ROUGE +# evaluation configuration script as follows: +# For example, to evaluate system X with 4 references R1, R2, R3, and R4. +# We do the following computation: +# +# for system: and for comparable human: +# s1 = X vs. R1, R2, R3 h1 = R4 vs. R1, R2, R3 +# s2 = X vs. R1, R3, R4 h2 = R2 vs. R1, R3, R4 +# s3 = X vs. R1, R2, R4 h3 = R3 vs. R1, R2, R4 +# s4 = X vs. R2, R3, R4 h4 = R1 vs. R2, R3, R4 +# +# Average system score for X = (s1+s2+s3+s4)/4 and for human = (h1+h2+h3+h4)/4 +# Implementation of this in a ROUGE evaluation configuration script is as follows: +# Instead of writing all references in a evaluation section as below: +# +# ... +# +#

+#      [XML evaluation block: peer "systemX" scored against models R1, R2, R3 and R4]
+#
+# we write the following:
+#
+#      [four XML evaluation blocks: peer "systemX" scored against the reference
+#       subsets (R2, R3, R4), (R1, R3, R4), (R1, R2, R4) and (R1, R2, R3)]
+# In this case, the system and human numbers are comparable.
+# ROUGE as it is implemented for summarization evaluation is a recall-based metric.
+# As we increase the number of references, we are increasing the number of
+# counting units (n-grams, skip-bigrams, or LCSes) in the target pool (i.e.
+# the number that ends up in the denominator of any ROUGE formula is larger).
+# Therefore, a candidate summary has more chances to hit but it also has to
+# hit more. In the end, this means lower absolute ROUGE scores when more
+# references are used, and runs using different sets of references should not
+# be compared to each other. There is no normalization mechanism in ROUGE
+# to properly adjust for differences due to the number of references used.
+#
+# In the ROUGE implementations before v1.4, when there are N models provided for
+# evaluating system X in the ROUGE evaluation script, ROUGE does the
+# following:
+# (1) s1 = X vs. R2, R3, R4, ..., RN
+# (2) s2 = X vs. R1, R3, R4, ..., RN
+# (3) s3 = X vs. R1, R2, R4, ..., RN
+# (4) s4 = X vs. R1, R2, R3, ..., RN
+# (5) ...
+# (6) sN = X vs. R1, R2, R3, ..., RN-1
+# And the final ROUGE score is computed by taking the average of (s1, s2, s3,
+# s4, ..., sN). When we provide only three references for evaluation of a
+# human summarizer, ROUGE does the same thing but uses 2 out of 3
+# references, gets three numbers, and then takes the average as the final
+# score. Now ROUGE (after v1.4) will use all references without this
+# internal Jackknifing procedure. The speed of the evaluation should improve
+# a lot, since only one set instead of four sets of computation will be
+# conducted.
+# 1.3
+# (1) Add skip bigram.
+# (2) Add an option to specify the number of sampling points (default is 1000).
+# 1.2.3
+# (1) Correct the environment variable option: -e. Now users can specify the environment
+#     variable ROUGE_EVAL_HOME using the "-e" option; previously this option was
+#     not active. Thanks to Zhouyan Li of Concordia University, Canada, for pointing this
+#     out.
+# 1.2.2
+# (1) Correct confidence interval calculation for median, maximum, and minimum.
+#     Line 390.
+# 1.2.1
+# (1) Add sentence per line input format. See files in Verify-SPL for examples.
+# (2) Streamline command line arguments.
+# (3) Use bootstrap resampling to estimate confidence intervals instead of using a t-test
+#     or z-test, which assume a normal distribution.
+# (4) Add LCS (longest common subsequence) evaluation method.
+# (5) Add WLCS (weighted longest common subsequence) evaluation method.
+# (6) Add length cutoff in bytes.
+# (7) Add an option to specify the longest ngram to compute. The default is 4.
+# 1.2
+# (1) Change zero condition check in subroutine &computeNGramScores when
+#     computing $gram1Score from
+#     if($totalGram2Count!=0) to
+#     if($totalGram1Count!=0)
+#     Thanks to Ken Litkowski for this bug report.
+#     This original script will set gram1Score to zero if there are no
+#     bigram matches. This should rarely have a significant effect on the final score,
+#     since (a) there are bigram matches most of the time; (b) the computation
+#     of gram1Score uses a Jackknifing procedure. However, it definitely
+#     did not compute the correct $gram1Score when there were no bigram matches.
+#     Therefore, users of version 1.1 should definitely upgrade to a newer
+#     version of the script that does not contain this bug.
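For reference, the weighted harmonic mean introduced in the 1.5.1 entry above (and exposed through the `-p alpha` option) can be written out as a short Python sketch. This only illustrates the stated formula F = 1/(alpha*(1/P) + (1-alpha)*(1/R)); it is not code from this package, and the zero-score guard is an assumption of the sketch.

```python
# Sketch of the ROUGE F-measure described in the 1.5.1 note above:
#   F = 1 / (alpha * (1/P) + (1 - alpha) * (1/R))
# alpha -> 1 favours precision, alpha -> 0 favours recall (the "-p" option).
def rouge_f_measure(precision: float, recall: float, alpha: float = 0.5) -> float:
    if precision <= 0.0 or recall <= 0.0:
        return 0.0  # guard against division by zero; this sketch simply returns 0.0
    return 1.0 / (alpha * (1.0 / precision) + (1.0 - alpha) * (1.0 / recall))


# With the default alpha = 0.5 this reduces to the usual balanced F1, e.g.
# rouge_f_measure(0.5, 0.25) == 1 / (0.5 * 2 + 0.5 * 4) == 1/3
```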
+# Note: To use this script, two additional data files are needed: +# (1) smart_common_words.txt - contains stopword list from SMART IR engine +# (2) WordNet-1.6.exc.db - WordNet 1.6 exception inflexion database +# These two files have to be put in a directory pointed by the environment +# variable: "ROUGE_EVAL_HOME". +# If environment variable ROUGE_EVAL_HOME does not exist, this script will +# will assume it can find these two database files in the current directory. diff --git a/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/ROUGE-1.5.5.pl b/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/ROUGE-1.5.5.pl new file mod e 100644 index 0000000000000000000000000000000000000000..974c667f8a308ce418f9206a8ff76c2f977bc367 --- /dev/null +++ b/model/third_party/HMNet/ThirdParty/ROUGE/ROUGE-1.5.5/ROUGE-1.5.5.pl @@ -0,0 +1,3300 @@ +#!/usr/bin/perl -w +# Add current dir to include +use File::Basename; +use lib dirname (__FILE__); + +# Version: ROUGE v1.5.5 +# Date: 05/26/2005,05/19/2005,04/26/2005,04/03/2005,10/28/2004,10/25/2004,10/21/2004 +# Author: Chin-Yew Lin +# Description: Given an evaluation description file, for example: test.xml, +# this script computes the averages of the average ROUGE scores for +# the evaluation pairs listed in the ROUGE evaluation configuration file. +# For more information, please see: +# http://www.isi.edu/~cyl/ROUGE +# For more information about Basic Elements, please see: +# http://www.isi.edu/~cyl/BE +# Revision Note: +# 1.5.5 +# (1) Correct stemming on multi-token BE heads and modifiers. +# Previously, only single token heads and modifiers were assumed. +# (2) Correct the resampling routine which ignores the last evaluation +# item in the evaluation list. Therefore, the average scores reported +# by ROUGE is only based on the first N-1 evaluation items. +# Thanks Barry Schiffman at Columbia University to report this bug. +# This bug only affects ROUGE-1.5.X. For pre-1.5 ROUGE, it only affects +# the computation of confidence interval (CI) estimation, i.e. CI is only +# estimated by the first N-1 evaluation items, but it *does not* affect +# average scores. +# (3) Change read_text and read_text_LCS functions to read exact words or +# bytes required by users. Previous versions carry out whitespace +# compression and other string clear up actions before enforce the length +# limit. +# 1.5.4.1 +# (1) Minor description change about "-t 0" option. +# 1.5.4 +# (1) Add easy evalution mode for single reference evaluations with -z +# option. +# 1.5.3 +# (1) Add option to compute ROUGE score based on SIMPLE BE format. Given +# a set of peer and model summary file in BE format with appropriate +# options, ROUGE will compute matching scores based on BE lexical +# matches. +# There are 6 options: +# 1. H : Head only match. This is similar to unigram match but +# only BE Head is used in matching. BEs generated by +# Minipar-based breaker do not include head-only BEs, +# therefore, the score will always be zero. Use HM or HMR +# optiions instead. +# 2. HM : Head and modifier match. This is similar to bigram or +# skip bigram but it's head-modifier bigram match based on +# parse result. Only BE triples with non-NIL modifier are +# included in the matching. +# 3. HMR : Head, modifier, and relation match. This is similar to +# trigram match but it's head-modifier-relation trigram +# match based on parse result. Only BE triples with non-NIL +# relation are included in the matching. +# 4. HM1 : This is combination of H and HM. 
It is similar to unigram + +# bigram or skip bigram with unigram match but it's +# head-modifier bigram match based on parse result. +# In this case, the modifier field in a BE can be "NIL" +# 5. HMR1 : This is combination of HM and HMR. It is similar to +# trigram match but it's head-modifier-relation trigram +# match based on parse result. In this case, the relation +# field of the BE can be "NIL". +# 6. HMR2 : This is combination of H, HM and HMR. It is similar to +# trigram match but it's head-modifier-relation trigram +# match based on parse result. In this case, the modifier and +# relation fields of the BE can both be "NIL". +# 1.5.2 +# (1) Add option to compute ROUGE score by token using the whole corpus +# as average unit instead of individual sentences. Previous versions of +# ROUGE uses sentence (or unit) boundary to break counting unit and takes +# the average score from the counting unit as the final score. +# Using the whole corpus as one single counting unit can potentially +# improve the reliablity of the final score that treats each token as +# equally important; while the previous approach considers each sentence as +# equally important that ignores the length effect of each individual +# sentences (i.e. long sentences contribute equal weight to the final +# score as short sentences.) +# +v1.2 provide a choice of these two counting modes that users can +# choose the one that fits their scenarios. +# 1.5.1 +# (1) Add precision oriented measure and f-measure to deal with different lengths +# in candidates and references. Importance between recall and precision can +# be controled by 'alpha' parameter: +# alpha -> 0: recall is more important +# alpha -> 1: precision is more important +# Following Chapter 7 in C.J. van Rijsbergen's "Information Retrieval". +# http://www.dcs.gla.ac.uk/Keith/Chapter.7/Ch.7.html +# F = 1/(alpha * (1/P) + (1 - alpha) * (1/R)) ;;; weighted harmonic mean +# 1.4.2 +# (1) Enforce length limit at the time when summary text is read. Previously (before +# and including v1.4.1), length limit was enforced at tokenization time. +# 1.4.1 +# (1) Fix potential over counting in ROUGE-L and ROUGE-W +# In previous version (i.e. 1.4 and order), LCS hit is computed +# by summing union hit over all model sentences. Each model sentence +# is compared with all peer sentences and mark the union LCS. The +# length of the union LCS is the hit of that model sentence. The +# final hit is then sum over all model union LCS hits. This potentially +# would over count a peer sentence which already been marked as contributed +# to some other model sentence. Therefore, double counting is resulted. +# This is seen in evalution where ROUGE-L score is higher than ROUGE-1 and +# this is not correct. +# ROUGEeval-1.4.1.pl fixes this by add a clip function to prevent +# double counting. +# 1.4 +# (1) Remove internal Jackknifing procedure: +# Now the ROUGE script will use all the references listed in the +# section in each section and no +# automatic Jackknifing is performed. Please see RELEASE-NOTE.txt +# for more details. +# 1.3 +# (1) Add skip bigram +# (2) Add an option to specify the number of sampling point (default is 1000) +# 1.2.3 +# (1) Correct the enviroment variable option: -e. Now users can specify evironment +# variable ROUGE_EVAL_HOME using the "-e" option; previously this option is +# not active. Thanks Zhouyan Li of Concordia University, Canada pointing this +# out. +# 1.2.2 +# (1) Correct confidence interval calculation for median, maximum, and minimum. +# Line 390. 
+# 1.2.1 +# (1) Add sentence per line format input format. See files in Verify-SPL for examples. +# (2) Streamline command line arguments. +# (3) Use bootstrap resampling to estimate confidence intervals instead of using t-test +# or z-test which assume a normal distribution. +# (4) Add LCS (longest common subsequence) evaluation method. +# (5) Add WLCS (weighted longest common subsequence) evaluation method. +# (6) Add length cutoff in bytes. +# (7) Add an option to specify the longest ngram to compute. The default is 4. +# 1.2 +# (1) Change zero condition check in subroutine &computeNGramScores when +# computing $gram1Score from +# if($totalGram2Count!=0) to +# if($totalGram1Count!=0) +# Thanks Ken Litkowski for this bug report. +# This original script will set gram1Score to zero if there is no +# bigram matches. This should rarely has significant affect the final score +# since (a) there are bigram matches most of time; (b) the computation +# of gram1Score is using Jackknifing procedure. However, this definitely +# did not compute the correct $gram1Score when there is no bigram matches. +# Therefore, users of version 1.1 should definitely upgrade to newer +# version of the script that does not contain this bug. +# Note: To use this script, two additional data files are needed: +# (1) smart_common_words.txt - contains stopword list from SMART IR engine +# (2) WordNet-2.0.exc.db - WordNet 2.0 exception inflexion database +# These two files have to be put in a directory pointed by the environment +# variable: "ROUGE_EVAL_HOME". +# If environment variable ROUGE_EVAL_HOME does not exist, this script will +# will assume it can find these two database files in the current directory. +# COPYRIGHT (C) UNIVERSITY OF SOUTHERN CALIFORNIA, 2002,2003,2004 +# University of Southern California +# Information Sciences Institute +# 4676 Admiralty Way +# Marina Del Rey, California 90292-6695 +# +# This software was partially developed under SPAWAR Grant No. +# N66001-00-1-8916 , and the Government holds license rights under +# DAR 7-104.9(a)(c)(1). It is +# transmitted outside of the University of Southern California only under +# written license agreements or software exchange agreements, and its use +# is limited by these agreements. At no time shall any recipient use +# this software in any manner which conflicts or interferes with the +# governmental license rights or other provisions of the governing +# agreement under which it is obtained. It is supplied "AS IS," without +# any warranties of any kind. It is furnished only on the basis that any +# party who receives it indemnifies and holds harmless the parties who +# furnish and originate it against any claims, demands or liabilities +# connected with using it, furnishing it to others or providing it to a +# third party. THIS NOTICE MUST NOT BE REMOVED FROM THE SOFTWARE, +# AND IN THE EVENT THAT THE SOFTWARE IS DIVIDED, IT SHOULD BE +# ATTACHED TO EVERY PART. +# +# Contributor to its design is Chin-Yew Lin. 
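The 1.5.1 revision note above defines the alpha-weighted F-measure that the script reports alongside recall and precision, F = 1/(alpha * (1/P) + (1 - alpha) * (1/R)). The following is a minimal Python sketch of that formula only; the function name `rouge_f_score` is illustrative and is not part of the ROUGE distribution or of HMNet.

```python
def rouge_f_score(recall: float, precision: float, alpha: float = 0.5) -> float:
    # F = 1 / (alpha * (1/P) + (1 - alpha) * (1/R))
    #   = (P * R) / (alpha * R + (1 - alpha) * P)
    # alpha -> 0 favors recall, alpha -> 1 favors precision; 0.5 is the plain F1.
    denominator = alpha * recall + (1.0 - alpha) * precision
    if denominator <= 0.0:
        return 0.0  # the script likewise reports 0 when there is nothing to score
    return (precision * recall) / denominator


# Example: 30 matching unigrams, 50 reference tokens, 40 candidate tokens
recall, precision = 30 / 50, 30 / 40
print(rouge_f_score(recall, precision))             # ~0.667 (alpha = 0.5)
print(rouge_f_score(recall, precision, alpha=0.0))  # 0.60 -> recall only
print(rouge_f_score(recall, precision, alpha=1.0))  # 0.75 -> precision only
```

With alpha = 0.5 this reduces to the ordinary harmonic mean of recall and precision, which is why the script's default (option -p omitted) treats the two as equally important.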
+ +use XML::DOM; +use DB_File; +use Getopt::Std; +#------------------------------------------------------------------------------------- +use vars qw($opt_a $opt_b $opt_c $opt_d $opt_e $opt_f $opt_h $opt_H $opt_m $opt_n $opt_p $opt_s $opt_t $opt_l $opt_v $opt_w $opt_2 $opt_u $opt_x $opt_U $opt_3 $opt_M $opt_z); +my $usageFull="$0\n [-a (evaluate all systems)] + [-c cf] + [-d (print per evaluation scores)] + [-e ROUGE_EVAL_HOME] + [-h (usage)] + [-H (detailed usage)] + [-b n-bytes|-l n-words] + [-m (use Porter stemmer)] + [-n max-ngram] + [-s (remove stopwords)] + [-r number-of-samples (for resampling)] + [-2 max-gap-length (if < 0 then no gap length limit)] + [-3 (for scoring based on BE)] + [-u (include unigram in skip-bigram) default no)] + [-U (same as -u but also compute regular skip-bigram)] + [-w weight (weighting factor for WLCS)] + [-v (verbose)] + [-x (do not calculate ROUGE-L)] + [-f A|B (scoring formula)] + [-p alpha (0 <= alpha <=1)] + [-t 0|1|2 (count by token instead of sentence)] + [-z ] + []\n +". + "ROUGE-eval-config-file: Specify the evaluation setup. Three files come with the ROUGE evaluation package, i.e.\n". + " ROUGE-test.xml, verify.xml, and verify-spl.xml are good examples.\n". + "systemID: Specify which system in the ROUGE-eval-config-file to perform the evaluation.\n". + " If '-a' option is used, then all systems are evaluated and users do not need to\n". + " provide this argument.\n". + "Default:\n". + " When running ROUGE without supplying any options (except -a), the following defaults are used:\n". + " (1) ROUGE-L is computed;\n". + " (2) 95% confidence interval;\n". + " (3) No stemming;\n". + " (4) Stopwords are inlcuded in the calculations;\n". + " (5) ROUGE looks for its data directory first through the ROUGE_EVAL_HOME environment variable. If\n". + " it is not set, the current directory is used.\n". + " (6) Use model average scoring formula.\n". + " (7) Assign equal importance of ROUGE recall and precision in computing ROUGE f-measure, i.e. alpha=0.5.\n". + " (8) Compute average ROUGE by averaging sentence (unit) ROUGE scores.\n". + "Options:\n". + " -2: Compute skip bigram (ROGUE-S) co-occurrence, also specify the maximum gap length between two words (skip-bigram)\n". + " -u: Compute skip bigram as -2 but include unigram, i.e. treat unigram as \"start-sentence-symbol unigram\"; -2 has to be specified.\n". + " -3: Compute BE score. Currently only SIMPLE BE triple format is supported.\n". + " H -> head only scoring (does not applied to Minipar-based BEs).\n". + " HM -> head and modifier pair scoring.\n". + " HMR -> head, modifier and relation triple scoring.\n". + " HM1 -> H and HM scoring (same as HM for Minipar-based BEs).\n". + " HMR1 -> HM and HMR scoring (same as HMR for Minipar-based BEs).\n". + " HMR2 -> H, HM and HMR scoring (same as HMR for Minipar-based BEs).\n". + " -a: Evaluate all systems specified in the ROUGE-eval-config-file.\n". + " -c: Specify CF\% (0 <= CF <= 100) confidence interval to compute. The default is 95\% (i.e. CF=95).\n". + " -d: Print per evaluation average score for each system.\n". + " -e: Specify ROUGE_EVAL_HOME directory where the ROUGE data files can be found.\n". + " This will overwrite the ROUGE_EVAL_HOME specified in the environment variable.\n". + " -f: Select scoring formula: 'A' => model average; 'B' => best model\n". + " -h: Print usage information.\n". + " -H: Print detailed usage information.\n". + " -b: Only use the first n bytes in the system/peer summary for the evaluation.\n". 
+ " -l: Only use the first n words in the system/peer summary for the evaluation.\n". + " -m: Stem both model and system summaries using Porter stemmer before computing various statistics.\n". + " -n: Compute ROUGE-N up to max-ngram length will be computed.\n". + " -p: Relative importance of recall and precision ROUGE scores. Alpha -> 1 favors precision, Alpha -> 0 favors recall.\n". + " -s: Remove stopwords in model and system summaries before computing various statistics.\n". + " -t: Compute average ROUGE by averaging over the whole test corpus instead of sentences (units).\n". + " 0: use sentence as counting unit, 1: use token as couting unit, 2: same as 1 but output raw counts\n". + " instead of precision, recall, and f-measure scores. 2 is useful when computation of the final,\n". + " precision, recall, and f-measure scores will be conducted later.\n". + " -r: Specify the number of sampling point in bootstrap resampling (default is 1000).\n". + " Smaller number will speed up the evaluation but less reliable confidence interval.\n". + " -w: Compute ROUGE-W that gives consecutive matches of length L in an LCS a weight of 'L^weight' instead of just 'L' as in LCS.\n". + " Typically this is set to 1.2 or other number greater than 1.\n". + " -v: Print debugging information for diagnositic purpose.\n". + " -x: Do not calculate ROUGE-L.\n". + " -z: ROUGE-eval-config-file is a list of peer-model pair per line in the specified format (SEE|SPL|ISI|SIMPLE).\n"; + +my $usage="$0\n [-a (evaluate all systems)] + [-c cf] + [-d (print per evaluation scores)] + [-e ROUGE_EVAL_HOME] + [-h (usage)] + [-H (detailed usage)] + [-b n-bytes|-l n-words] + [-m (use Porter stemmer)] + [-n max-ngram] + [-s (remove stopwords)] + [-r number-of-samples (for resampling)] + [-2 max-gap-length (if < 0 then no gap length limit)] + [-3 (for scoring based on BE)] + [-u (include unigram in skip-bigram) default no)] + [-U (same as -u but also compute regular skip-bigram)] + [-w weight (weighting factor for WLCS)] + [-v (verbose)] + [-x (do not calculate ROUGE-L)] + [-f A|B (scoring formula)] + [-p alpha (0 <= alpha <=1)] + [-t 0|1|2 (count by token instead of sentence)] + [-z ] + [] +"; +getopts('ahHb:c:de:f:l:mMn:p:st:r:2:3:w:uUvxz:'); +my $systemID; + +die $usageFull if defined($opt_H); +die $usage if defined($opt_h)||@ARGV==0; +die "Please specify the ROUGE configuration file or use option '-h' for help\n" if(@ARGV==0); +if(@ARGV==1&&defined($opt_z)) { + $systemID="X"; # default system ID +} +elsif(@ARGV==1&&!defined($opt_a)) { + die "Please specify a system ID to evaluate or use option '-a' to evaluate all systems. 
For more information, use option '-h'.\n";
+}
+elsif(@ARGV==2) {
+ $systemID=$ARGV[1];
+}
+if(defined($opt_e)) {
+ $stopwords="$opt_e/smart_common_words.txt";
+ $wordnetDB="$opt_e/WordNet-2.0.exc.db";
+}
+else {
+ if(exists($ENV{"ROUGE_EVAL_HOME"})) {
+ $stopwords="$ENV{\"ROUGE_EVAL_HOME\"}/smart_common_words.txt";
+ $wordnetDB="$ENV{\"ROUGE_EVAL_HOME\"}/WordNet-2.0.exc.db";
+ }
+ elsif(exists($ENV{"RED_EVAL_HOME"})) {
+ $stopwords="$ENV{\"RED_EVAL_HOME\"}/smart_common_words.txt";
+ $wordnetDB="$ENV{\"RED_EVAL_HOME\"}/WordNet-2.0.exc.db";
+ }
+ else {
+ # if no environment variable exists then assume data files are in the current directory
+ $stopwords="smart_common_words.txt";
+ $wordnetDB="WordNet-2.0.exc.db";
+ }
+}
+
+if(defined($opt_s)) {
+ $useStopwords=0; # do not use stop words
+}
+else {
+ $useStopwords=1; # use stop words
+}
+
+if(defined($opt_l)&&defined($opt_b)) {
+ die "Please specify length limit in words or bytes but not both.\n";
+}
+
+if(defined($opt_l)) {
+ $lengthLimit=$opt_l;
+ $byteLimit=0; # no byte limit
+}
+elsif(defined($opt_b)) {
+ $lengthLimit=0; # no length limit in words
+ $byteLimit=$opt_b;
+}
+else {
+ $byteLimit=0; # no byte limit
+ $lengthLimit=0; # no length limit
+}
+
+unless(defined($opt_c)) {
+ $opt_c=95;
+}
+else {
+ if($opt_c<0||$opt_c>100) {
+ die "Confidence interval should be within 0 and 100. Use option -h for more details.\n";
+ }
+}
+
+if(defined($opt_w)) {
+ if($opt_w>0) {
+ $weightFactor=$opt_w;
+ }
+ else {
+ die "ROUGE-W weight factor must greater than 0.\n";
+ }
+}
+#unless(defined($opt_n)) {
+# $opt_n=4; # default maximum ngram is 4
+#}
+if(defined($opt_v)) {
+ $debug=1;
+}
+else {
+ $debug=0;
+}
+
+if(defined($opt_r)) {
+ $numOfResamples=$opt_r;
+}
+else {
+ $numOfResamples=1000;
+}
+
+if(defined($opt_2)) {
+ $skipDistance=$opt_2;
+}
+
+if(defined($opt_3)) {
+ $BEMode=$opt_3;
+}
+
+if(defined($opt_f)) {
+ $scoreMode=$opt_f;
+}
+else {
+ $scoreMode="A"; # default: use model average scoring formula
+}
+
+if(defined($opt_p)) {
+ $alpha=$opt_p;
+ if($alpha<0||
+ $alpha>1) {
+ die "Relative importance of ROUGE recall and precision has to be between 0 and 1 inclusively.\n";
+ }
+}
+else {
+ $alpha=0.5; # default is equal importance of ROUGE recall and precision
+}
+
+if(defined($opt_t)) {
+ # make $opt_t as undef when appropriate option is given
+ # when $opt_t is undef, sentence level average will be used
+ if($opt_t==0) {
+ $opt_t=undef;
+ }
+ elsif($opt_t!=1&&
+ $opt_t!=2) {
+ $opt_t=undef; # other than 1 or 2, let $opt_t to be undef
+ }
+}
+
+if(defined($opt_z)) {
+ # If opt_z is specified, the user has to specify a system ID that
+ # is used for identification therefore -a option is not allowed.
+ # Here we make it undef.
+ $opt_a=undef; +} +#------------------------------------------------------------------------------------- +# Setup ROUGE scoring parameters +%ROUGEParam=(); # ROUGE scoring parameter +if(defined($lengthLimit)) { + $ROUGEParam{"LENGTH"}=$lengthLimit; +} +else { + $ROUGEParam{"LENGTH"}=undef; +} +if(defined($byteLimit)) { + $ROUGEParam{"BYTE"}=$byteLimit; +} +else { + $ROUGEParam{"BYTE"}=undef; +} +if(defined($opt_n)) { # ngram size + $ROUGEParam{"NSIZE"}=$opt_n; +} +else { + $ROUGEParam{"NSIZE"}=undef; +} +if(defined($weightFactor)) { + $ROUGEParam{"WEIGHT"}=$weightFactor; +} +else { + $ROUGEParam{"WEIGHT"}=undef; +} +if(defined($skipDistance)) { + $ROUGEParam{"SD"}=$skipDistance; +} +else { + $ROUGEParam{"SD"}=undef; +} +if(defined($scoreMode)) { + $ROUGEParam{"SM"}=$scoreMode; +} +else { + $ROUGEParam{"SM"}=undef; +} +if(defined($alpha)) { + $ROUGEParam{"ALPHA"}=$alpha; +} +else { + $ROUGEParam{"ALPHA"}=undef; +} +if(defined($opt_t)) { + $ROUGEParam{"AVERAGE"}=$opt_t; +} +else { + $ROUGEParam{"AVERAGE"}=undef; +} +if(defined($opt_3)) { + $ROUGEParam{"BEMODE"}=$opt_3; +} +else { + $ROUGEParam{"BEMODE"}=undef; +} +#------------------------------------------------------------------------------------- +# load stopwords +%stopwords=(); +open(STOP,$stopwords)||die "Cannot open $stopwords\n"; +while(defined($line=)) { + chomp($line); + $stopwords{$line}=1; +} +close(STOP); +# load WordNet database +if(-e "$wordnetDB") { + tie %exceptiondb,'DB_File',"$wordnetDB",O_RDONLY,0440,$DB_HASH or + die "Cannot open exception db file for reading: $wordnetDB\n"; +} +else { + die "Cannot open exception db file for reading: $wordnetDB\n"; +} +#------------------------------------------------------------------------------------- +# Initialize Porter Stemmer +&initialise(); +#------------------------------------------------------------------------------------- +# Read and parse the document +my $parser = new XML::DOM::Parser; +my $doc; +unless(defined($opt_z)) { + $doc=$parser->parsefile($ARGV[0]); +} +else { + open($doc,$ARGV[0])||die "Cannot open $ARGV[0]\n"; +} +%ROUGEEvals=(); +@ROUGEEvalIDs=(); +%ROUGEPeerIDTable=(); +@allPeerIDs=(); +%knownMissing=(); # remember missing submission already known +if(defined($doc)) { + # read evaluation description file + &readEvals(\%ROUGEEvals,\@ROUGEEvalIDs,\%ROUGEPeerIDTable,$doc,undef); + # print evaluation configuration + if(defined($opt_z)) { + if(defined($ARGV[1])) { + $systemID=$ARGV[1]; + } + else { + $systemID="X"; # default system ID in BE file list evaluation mode + } + push(@allPeerIDs,$systemID); + } + else { + unless(defined($opt_a)) { + $systemID=$ARGV[1]; + push(@allPeerIDs,$systemID); + } + else { + # run evaluation for each peer listed in the description file + @allPeerIDs=sort (keys %ROUGEPeerIDTable); + } + } + foreach $peerID (@allPeerIDs) { + %testIDs=(); + # print "\@PEER($peerID)-------------------------------------------------- \n"; + if(defined($opt_n)) { + # evaluate a specific peer + # compute ROUGE score up to $opt_n-gram + for($n=1;$n<=$opt_n;$n++) { + my (%ROUGEScores,%ROUGEAverages); + + %ROUGEScores=(); + foreach $e (@ROUGEEvalIDs) { + if($debug) { + print "\@Eval ($e)\n"; + } + $ROUGEParam{"NSIZE"}=$n; + &computeROUGEX("N",\%ROUGEScores,$e,$ROUGEEvals{$e},$peerID,\%ROUGEParam); + } + # compute averages + %ROUGEAverages=(); + &computeAverages(\%ROUGEScores,\%ROUGEAverages,$opt_t); + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-$n",$opt_c,$opt_t,$opt_d); + } + } + unless(defined($opt_x)||defined($opt_3)) { + 
#----------------------------------------------- + # compute LCS score + %ROUGEScores=(); + foreach $e (@ROUGEEvalIDs) { + &computeROUGEX("L",\%ROUGEScores,$e,$ROUGEEvals{$e},$peerID,\%ROUGEParam); + } + # compute averages + %ROUGEAverages=(); + &computeAverages(\%ROUGEScores,\%ROUGEAverages,$opt_t); + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-L",$opt_c,$opt_t,$opt_d); + } + if(defined($opt_w)) { + #----------------------------------------------- + # compute WLCS score + %ROUGEScores=(); + foreach $e (@ROUGEEvalIDs) { + &computeROUGEX("W",\%ROUGEScores,$e,$ROUGEEvals{$e},$peerID,\%ROUGEParam); + } + # compute averages + %ROUGEAverages=(); + &computeAverages(\%ROUGEScores,\%ROUGEAverages,$opt_t); + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-W-$weightFactor",$opt_c,$opt_t,$opt_d); + } + if(defined($opt_2)) { + #----------------------------------------------- + # compute skip bigram score + %ROUGEScores=(); + foreach $e (@ROUGEEvalIDs) { + &computeROUGEX("S",\%ROUGEScores,$e,$ROUGEEvals{$e},$peerID,\%ROUGEParam); + } + # compute averages + %ROUGEAverages=(); + &computeAverages(\%ROUGEScores,\%ROUGEAverages,$opt_t); + if($skipDistance>=0) { + if(defined($opt_u)) { + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-SU$skipDistance",$opt_c,$opt_t,$opt_d); + } + elsif(defined($opt_U)) { + # print regular skip bigram results + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-S$skipDistance",$opt_c,$opt_t,$opt_d); + #----------------------------------------------- + # compute skip bigram with unigram extension score + $opt_u=1; + %ROUGEScores=(); + foreach $e (@ROUGEEvalIDs) { + &computeROUGEX("S",\%ROUGEScores,$e,$ROUGEEvals{$e},$peerID,\%ROUGEParam); + } + $opt_u=undef; + # compute averages + %ROUGEAverages=(); + &computeAverages(\%ROUGEScores,\%ROUGEAverages,$opt_t); + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-SU$skipDistance",$opt_c,$opt_t,$opt_d); + } + else { + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-S$skipDistance",$opt_c,$opt_t,$opt_d); + } + } + else { + if(defined($opt_u)) { + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-SU*",$opt_c,$opt_t,$opt_d); + } + else { + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-S*",$opt_c,$opt_t,$opt_d); + if(defined($opt_U)) { + #----------------------------------------------- + # compute skip bigram with unigram extension score + $opt_u=1; + %ROUGEScores=(); + foreach $e (@ROUGEEvalIDs) { + &computeROUGEX("S",\%ROUGEScores,$e,$ROUGEEvals{$e},$peerID,\%ROUGEParam); + } + $opt_u=undef; + # compute averages + %ROUGEAverages=(); + &computeAverages(\%ROUGEScores,\%ROUGEAverages,$opt_t); + &printResults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-SU*",$opt_c,$opt_t,$opt_d); + } + } + } + } + if(defined($opt_3)) { + #----------------------------------------------- + # compute Basic Element triple score + %ROUGEScores=(); + foreach $e (@ROUGEEvalIDs) { + &computeROUGEX("BE",\%ROUGEScores,$e,$ROUGEEvals{$e},$peerID,\%ROUGEParam); + } + # compute averages + %ROUGEAverages=(); + &computeAverages(\%ROUGEScores,\%ROUGEAverages,$opt_t); + &printR esults($peerID,\%ROUGEAverages,\%ROUGEScores,"ROUGE-BE-$BEMode",$opt_c,$opt_t,$opt_d); + } + } +} +else { + die "Document undefined\n"; +} +if(defined($opt_z)) { + close($doc); +} +untie %exceptiondb; + +sub printResults { + my $peerID=shift; + my $ROUGEAverages=shift; + my $ROUGEScores=shift; + my $methodTag=shift; + my $opt_c=shift; + my $opt_t=shift; + my $opt_d=shift; + + print 
"---------------------------------------------\n"; + if(!defined($opt_t)||$opt_t==1) { + print "$peerID $methodTag Average_R: $ROUGEAverages->{'AvgR'} "; + print "($opt_c\%-conf.int. $ROUGEAverages->{'CIAvgL_R'} - $ROUGEAverages->{'CIAvgU_R'})\n"; + print "$peerID $methodTag Average_P: $ROUGEAverages->{'AvgP'} "; + print "($opt_c\%-conf.int. $ROUGEAverages->{'CIAvgL_P'} - $ROUGEAverages->{'CIAvgU_P'})\n"; + print "$peerID $methodTag Average_F: $ROUGEAverages->{'AvgF'} "; + print "($opt_c\%-conf.int. $ROUGEAverages->{'CIAvgL_F'} - $ROUGEAverages->{'CIAvgU_F'})\n"; + } + else { + print "$peerID $methodTag M_count: "; + print int($ROUGEAverages->{'M_cnt'}); + print " P_count: "; + print int($ROUGEAverages->{'P_cnt'}); + print " H_count: "; + print int($ROUGEAverages->{'H_cnt'}); + print "\n"; + } + if(defined($opt_d)) { + print ".............................................\n"; + &printPerEvalData($ROUGEScores,"$peerID $methodTag Eval"); + } +} + +sub bootstrapResampling { + my $scores=shift; + my $instances=shift; + my $seed=shift; + my $opt_t=shift; + my $sample; + my ($i,$ridx); + + # Use $seed to seed the random number generator to make sure + # we have the same random sequence every time, therefore a + # consistent estimation of confidence interval in different runs. + # This is not necessary. To ensure a consistent result in reporting + # results using ROUGE, this is implemented. + srand($seed); + for($i=0;$i<@{$instances};$i++) { + # generate a random index + $ridx=int(rand(@{$instances})); + unless(defined($sample)) { + # setup the resampling array + $sample=[]; + push(@$sample,$scores->{$instances->[$ridx]}[0]); + push(@$sample,$scores->{$instances->[$ridx]}[1]); + push(@$sample,$scores->{$instances->[$ridx]}[2]); + } + else { + # update the resampling array + $sample->[0]+=$scores->{$instances->[$ridx]}[0]; + $sample->[1]+=$scores->{$instances->[$ridx]}[1]; + $sample->[2]+=$scores->{$instances->[$ridx]}[2]; + } + } + # compute the average result for this resampling procedure + unless(defined($opt_t)) { + # per instance or sentence average + if(@{$instances}>0) { + $sample->[0]/=@{$instances}; + $sample->[1]/=@{$instances}; + $sample->[2]/=@{$instances}; + } + else { + $sample->[0]=0; + $sample->[1]=0; + $sample->[2]=0; + } + } + else { + if($opt_t==1) { + # per token or corpus level average + # output recall, precision, and f-measure score + my ($tmpR,$tmpP,$tmpF); + if($sample->[0]>0) { + $tmpR=$sample->[2]/$sample->[0]; # recall + } + else { + $tmpR=0; + } + if($sample->[1]>0) { + $tmpP=$sample->[2]/$sample->[1]; # precision + } + else { + $tmpP=0; + } + if((1-$alpha)*$tmpP+$alpha*$tmpR>0) { + $tmpF=($tmpR*$tmpP)/((1-$alpha)*$tmpP+$alpha*$tmpR); # f-measure + } + else { + $tmpF=0; + } + $sample->[0]=$tmpR; + $sample->[1]=$tmpP; + $sample->[2]=$tmpF; + } + else { + # $opt_t!=1 => output raw model token count, peer token count, and hit count + # do nothing, just return $sample + } + } + return $sample; +} + +sub by_value { + $a<=>$b; +} + +sub printPerEvalData { + my $ROUGEScores=shift; + my $tag=shift; # tag to identify each evaluation + my (@instances,$i,$j); + + @instances=sort by_evalID (keys %$ROUGEScores); + foreach $i (@instances) { + # print average per evaluation score + print "$tag $i R:$ROUGEScores->{$i}[0] P:$ROUGEScores->{$i}[1] F:$ROUGEScores->{$i}[2]\n"; + } +} + +sub by_evalID { + my ($a1,$b1); + + if($a=~/ ^([0-9]+)/o) { + $a1=$1; + } + if($b=~/^([0-9]+)/o) { + $b1=$1; + } + if(defined($a1)&&defined($b1)) { + return $a1<=>$b1; + } + else { + return $a cmp $b; + } +} 
+ +sub computeAverages { + my $ROUGEScores=shift; + my $ROUGEAverages=shift; + my $opt_t=shift; + my ($avgAvgROUGE_R,$resampleAvgROUGE_R); + my ($avgAvgROUGE_P,$resampleAvgROUGE_P); + my ($avgAvgROUGE_F,$resampleAvgROUGE_F); + my ($ciU,$ciL); + my (@instances,$i,$j,@rankedArray_R,@rankedArray_P,@RankedArray_F); + + @instances=sort (keys %$ROUGEScores); + $avgAvgROUGE_R=0; + $avgAvgROUGE_P=0; + $avgAvgROUGE_F=0; + $resampleAvgROUGE_R=0; + $resampleAvgROUGE_P=0; + $resampleAvgROUGE_F=0; + # compute totals + foreach $i (@instances) { + $avgAvgROUGE_R+=$ROUGEScores->{$i}[0]; # recall ; or model token count + $avgAvgROUGE_P+=$ROUGEScores->{$i}[1]; # precision ; or peer token count + $avgAvgROUGE_F+=$ROUGEScores->{$i}[2]; # f1-measure ; or match token count (hit) + } + # compute averages + unless(defined($opt_t)) { + # per sentence average + if((scalar @instances)>0) { + $avgAvgROUGE_R=sprintf("%7.5f",$avgAvgROUGE_R/(scalar @instances)); + $avgAvgROUGE_P=sprintf("%7.5f",$avgAvgROUGE_P/(scalar @instances)); + $avgAvgROUGE_F=sprintf("%7.5f",$avgAvgROUGE_F/(scalar @instances)); + } + else { + $avgAvgROUGE_R=sprintf("%7.5f",0); + $avgAvgROUGE_P=sprintf("%7.5f",0); + $avgAvgROUGE_F=sprintf("%7.5f",0); + } + } + else { + if($opt_t==1) { + # per token average on corpus level + my ($tmpR,$tmpP,$tmpF); + if($avgAvgROUGE_R>0) { + $tmpR=$avgAvgROUGE_F/$avgAvgROUGE_R; + } + else { + $tmpR=0; + } + if($avgAvgROUGE_P>0) { + $tmpP=$avgAvgROUGE_F/$avgAvgROUGE_P; + } + else { + $tmpP=0; + } + if((1-$alpha)*$tmpP+$alpha*$tmpR>0) { + $tmpF=($tmpR+$tmpP)/((1-$alpha)*$tmpP+$alpha*$tmpR); + } + else { + $tmpF=0; + } + $avgAvgROUGE_R=sprintf("%7.5f",$tmpR); + $avgAvgROUGE_P=sprintf("%7.5f",$tmpP); + $avgAvgROUGE_F=sprintf("%7.5f",$tmpF); + } + } + if(!defined($opt_t)||$opt_t==1) { + # compute confidence intervals using bootstrap resampling + @ResamplingArray=(); + for($i=0;$i<$numOfResamples;$i++) { + my $sample; + + $sample=&bootstrapResampling($ROUGEScores,\@instances,$i,$opt_t); + # sample contains average sum of the sample + if(@ResamplingArray==0) { + # setup the resampling array for Avg + my $s; + + $s=[]; + push(@$s,$sample->[0]); + push(@ResamplingArray,$s); + $s=[]; + push(@$s,$sample->[1]); + push(@ResamplingArray,$s); + $s=[]; + push(@$s,$sample->[2]); + push(@ResamplingArray,$s); + } + else { + $rsa=$ResamplingArray[0]; + push(@{$rsa},$sample->[0]); + $rsa=$ResamplingArray[1]; + push(@{$rsa},$sample->[1]); + $rsa=$ResamplingArray[2]; + push(@{$rsa},$sample->[2]); + } + } + # sort resampling results + { + # recall + @rankedArray_R=sort by_value (@{$ResamplingArray[0]}); + $ResamplingArray[0]=\@rankedArray_R; + for($x=0;$x<=$#rankedArray_R;$x++) { + $resampleAvgROUGE_R+=$rankedArray_R[$x]; + # print "*R ($x): $rankedArray_R[$x]\n"; + } + $resampleAvgROUGE_R=sprintf("%7.5f",$resampleAvgROUGE_R/(scalar @rankedArray_R)); + # precision + @rankedArray_P=sort by_value (@{$ResamplingArray[1]}); + $ResamplingArray[1]=\@rankedArray_P; + for($x=0;$x<=$#rankedArray_P;$x++) { + $resampleAvgROUGE_P+=$rankedArray_P[$x]; + # print "*P ($x): $rankedArray_P[$x]\n"; + } + $resampleAvgROUGE_P=sprintf("%7.5f",$resampleAvgROUGE_P/(scalar @rankedArray_P)); + # f1-measure + @rankedArray_F=sort by_value (@{$ResamplingArray[2]}); + $ResamplingArray[2]=\@rankedArray_F; + for($x=0;$x<=$#rankedArray_F;$x++) { + $resampleAvgROUGE_F+=$rankedArray_F[$x]; + # print "*F ($x): $rankedArray_F[$x]\n"; + } + $resampleAvgROUGE_F=sprintf("%7.5f",$resampleAvgROUGE_F/(scalar @rankedArray_F)); + } + # $ciU=999-int((100-$opt_c)*10/2); # upper bound 
index + # $ciL=int((100-$opt_ c)*10/2); # lower bound index + $delta=$numOfResamples*((100-$opt_c)/2.0)/100.0; + $ciUa=int($numOfResamples-$delta-1); # upper confidence interval lower index + $ciUb=$ciUa+1; # upper confidence interval upper index + $ciLa=int($delta); # lower confidence interval lower index + $ciLb=$ciLa+1; # lower confidence interval upper index + $ciR=$numOfResamples-$delta-1-$ciUa; # ratio bewteen lower and upper indexes + # $ROUGEAverages->{"AvgR"}=$avgAvgROUGE_R; + #------- + # recall + $ROUGEAverages->{"AvgR"}=$resampleAvgROUGE_R; + # find condifence intervals; take maximum distance from the mean + $ROUGEAverages->{"CIAvgL_R"}=sprintf("%7.5f",$ResamplingArray[0][$ciLa]+ + ($ResamplingArray[0][$ciLb]-$ResamplingArray[0][$ciLa])*$ciR); + $ROUGEAverages->{"CIAvgU_R"}=sprintf("%7.5f",$ResamplingArray[0][$ciUa]+ + ($ResamplingArray[0][$ciUb]-$ResamplingArray[0][$ciUa])*$ciR); + #------- + # precision + $ROUGEAverages->{"AvgP"}=$resampleAvgROUGE_P; + # find condifence intervals; take maximum distance from the mean + $ROUGEAverages->{"CIAvgL_P"}=sprintf("%7.5f",$ResamplingArray[1][$ciLa]+ + ($ResamplingArray[1][$ciLb]-$ResamplingArray[1][$ciLa])*$ciR); + $ROUGEAverages->{"CIAvgU_P"}=sprintf("%7.5f",$ResamplingArray[1][$ciUa]+ + ($ResamplingArray[1][$ciUb]-$ResamplingArray[1][$ciUa])*$ciR); + #------- + # f1-measure + $ROUGEAverages->{"AvgF"}=$resampleAvgROUGE_F; + # find condifence intervals; take maximum distance from the mean + $ROUGEAverages->{"CIAvgL_F"}=sprintf("%7.5f",$ResamplingArray[2][$ciLa]+ + ($ResamplingArray[2][$ciLb]-$ResamplingArray[2][$ciLa])*$ciR); + $ROUGEAverages->{"CIAvgU_F"}=sprintf("%7.5f",$ResamplingArray[2][$ciUa]+ + ($ResamplingArray[2][$ciUb]-$ResamplingArray[2][$ciUa])*$ciR); + $ROUGEAverages->{"M_cnt"}=$avgAvgROUGE_R; # model token count + $ROUGEAverages->{"P_cnt"}=$avgAvgROUGE_P; # peer token count + $ROUGEAverages->{"H_cnt"}=$avgAvgROUGE_F; # hit token count + } + else { + # $opt_t==2 => output raw count instead of precision, recall, and f-measure values + # in this option, no resampling is necessary, just output the raw counts + $ROUGEAverages->{"M_cnt"}=$avgAvgROUGE_R; # model token count + $ROUGEAverages->{"P_cnt"}=$avgAvgROUGE_P; # peer token count + $ROUGEAverages->{"H_cnt"}=$avgAvgROUGE_F; # hit token count + } +} + +sub computeROUGEX { + my $metric=shift; # which ROUGE metric to compute? + my $ROUGEScores=shift; + my $evalID=shift; + my $ROUGEEval=shift; # one particular evaluation pair + my $peerID=shift; # a specific peer ID + my $ROUGEParam=shift; # ROUGE scoring parameters + my $lengthLimit; # lenght limit in words + my $byteLimit; # length limit in bytes + my $NSIZE; # ngram size for ROUGE-N + my $weightFactor; # weight factor for ROUGE-W + my $skipDistance; # skip distance for ROUGE-S + my $scoreMode; # scoring mode: A = model average; B = best model + my $alpha; # relative importance between recall and precision + my $opt_t; # ROUGE score counting mode + my $BEMode; # Basic Element scoring mode + my ($c,$cx,@modelPaths,$modelIDs,$modelRoot,$inputFormat); + + $lengthLimit=$ROUGEParam->{"LENGTH"}; + $byteLimit=$ROUGEParam->{"BYTE"}; + $NSIZE=$ROUGEParam->{"NSIZE"}; + $weightFactor=$ROUGEParam->{"WEIGHT"}; + $skipDistance=$ROUGEParam->{"SD"}; + $scoreMode=$ROUGEParam->{"SM"}; + $alpha=$ROUGEParam->{"ALPHA"}; + $opt_t=$ROUGEParam->{"AVERAGE"}; + $BEMode=$ROUGEParam->{"BEMODE"}; + + # Check to see if this evaluation trial contains this $peerID. + # Sometimes not every peer provides response for each + # evaluation trial. 
+ unless(exists($ROUGEEval->{"Ps"}{$peerID