aliabd commited on
Commit
7e3e85d
1 Parent(s): d26e36a

full demo working with old gradio

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .idea/SummerTime.iml +8 -0
  2. .idea/inspectionProfiles/Project_Default.xml +16 -0
  3. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  4. .idea/modules.xml +8 -0
  5. LICENSE +202 -0
  6. README.md +1 -1
  7. SummerTime.egg-info/PKG-INFO +124 -0
  8. SummerTime.egg-info/SOURCES.txt +46 -0
  9. SummerTime.egg-info/dependency_links.txt +1 -0
  10. SummerTime.egg-info/top_level.txt +4 -0
  11. __init__.py +3 -0
  12. app.py +28 -0
  13. build/scripts-3.9/summertime +3 -0
  14. dataset/__init__.py +36 -0
  15. dataset/dataset_loaders.py +501 -0
  16. dataset/non_huggingface_datasets_builders/arxiv_longsummarization.py +104 -0
  17. dataset/non_huggingface_datasets_builders/qmsum.py +119 -0
  18. dataset/non_huggingface_datasets_builders/scisummnet.py +105 -0
  19. dataset/non_huggingface_datasets_builders/summscreen.py +123 -0
  20. dataset/st_dataset.py +281 -0
  21. dependencies.txt +11 -0
  22. dist/SummerTime-0.1-py3-none-any.whl +0 -0
  23. download.py +3 -0
  24. evaluation/__init__.py +14 -0
  25. evaluation/base_metric.py +27 -0
  26. evaluation/bertscore_metric.py +20 -0
  27. evaluation/bleu_metric.py +20 -0
  28. evaluation/meteor_metric.py +31 -0
  29. evaluation/rouge_metric.py +23 -0
  30. evaluation/rougewe_metric.py +24 -0
  31. evaluation/summeval_metric.py +18 -0
  32. model/__init__.py +34 -0
  33. model/base_model.py +81 -0
  34. model/defaults.py +10 -0
  35. model/dialogue/__init__.py +1 -0
  36. model/dialogue/hmnet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json +1 -0
  37. model/dialogue/hmnet/ExampleRawData/meeting_summarization/role_dict_ext.json +1 -0
  38. model/dialogue/hmnet/config/dialogue.conf +98 -0
  39. model/dialogue/hmnet_model.py +483 -0
  40. model/multi_doc/__init__.py +2 -0
  41. model/multi_doc/base_multi_doc_model.py +40 -0
  42. model/multi_doc/multi_doc_joint_model.py +51 -0
  43. model/multi_doc/multi_doc_separate_model.py +49 -0
  44. model/query_based/__init__.py +2 -0
  45. model/query_based/base_query_based_model.py +147 -0
  46. model/query_based/bm25_model.py +45 -0
  47. model/query_based/tf_idf_model.py +46 -0
  48. model/single_doc/__init__.py +5 -0
  49. model/single_doc/bart_model.py +36 -0
  50. model/single_doc/base_single_doc_model.py +36 -0
.idea/SummerTime.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="3">
8
+ <item index="0" class="java.lang.String" itemvalue="onnxruntime" />
9
+ <item index="1" class="java.lang.String" itemvalue="onnx_tf" />
10
+ <item index="2" class="java.lang.String" itemvalue="onnx" />
11
+ </list>
12
+ </value>
13
+ </option>
14
+ </inspection_tool>
15
+ </profile>
16
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/SummerTime.iml" filepath="$PROJECT_DIR$/.idea/SummerTime.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ https://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021 SummerTime
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ https://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
202
+
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: SummerTime
3
- emoji: 💩
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
1
  ---
2
  title: SummerTime
3
+ emoji: 🔥
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
SummerTime.egg-info/PKG-INFO ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: SummerTime
3
+ Version: 0.1
4
+ Summary: A summarization model
5
+ Home-page: https://github.com/LILYlab
6
+ Author: Ansong Ni, Murori Mutuma, Zhangir Azerbayev, Yusen Zhang, Tao Yu, Dragomir Radev
7
+ Author-email: ansong.ni@yale.edu, murorimutuma@gmail.com, zhangir.azerbayev@yale.edu
8
+ License: UNKNOWN
9
+ Description: # SummerTime
10
+
11
+ A library to help users choose appropriate summarization tools based on their specific tasks or needs. Includes models, evaluation metrics, and datasets.
12
+
13
+
14
+
15
+ ## Installation and setup
16
+
17
+ #### Create and activate a new `conda` environment:
18
+ ```bash
19
+ conda create -n st python=3.7
20
+ conda activate st
21
+ ```
22
+
23
+ #### `pip` dependencies for local demo:
24
+ ```bash
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+
29
+
30
+ ## Quick Start
31
+ Imports model, initializes default model, and summarizes sample documents.
32
+ ```python
33
+ import model as st_model
34
+
35
+ model = st_model.summarizer()
36
+ documents = [
37
+ """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions.
38
+ The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected
39
+ by the shutoffs which were expected to last through at least midday tomorrow."""
40
+ ]
41
+ model.summarize(documents)
42
+
43
+ # ["California's largest electricity provider has turned off power to hundreds of thousands of customers."]
44
+ ```
45
+
46
+ Also, please run the `demo.ipynb` demo Jupyter notebook for more examples. To start the demo Jupyter notebook on localhost:
47
+ ```bash
48
+ jupyter notebook demo.ipynb
49
+ ```
50
+
51
+
52
+
53
+ ## Models
54
+ Import and initialization:
55
+ ```python
56
+ import model as st_model
57
+
58
+ default_model = st_model.summarizer()
59
+ bart_model = st_model.bart_model.BartModel()
60
+ pegasus_model = st_model.pegasus_model.PegasusModel()
61
+ lexrank_model = st_model.lexrank_model.LexRankModel()
62
+ textrank_model = st_model.textrank_model.TextRankModel()
63
+ ```
64
+
65
+ All models can be initialized with the following optional options:
66
+ ```python
67
+ def __init__(self,
68
+ trained_domain: str=None,
69
+ max_input_length: int=None,
70
+ max_output_length: int=None,
71
+ ):
72
+ ```
73
+
74
+ All models implement the following methods:
75
+ ```python
76
+ def summarize(self,
77
+ corpus: Union[List[str], List[List[str]]],
78
+ queries: List[str]=None) -> List[str]:
79
+
80
+ def show_capability(cls) -> None:
81
+
82
+ def generate_basic_description(cls) -> str:
83
+ ```
84
+
85
+
86
+
87
+ ## Evaluation
88
+ Import and initialization:
89
+ ```python
90
+ import eval as st_eval
91
+
92
+ bert_eval = st_eval.bertscore()
93
+ bleu_eval = st_eval.bleu_eval()
94
+ rouge_eval = st_eval.rouge()
95
+ rougewe_eval = st_eval.rougewe()
96
+ ```
97
+
98
+ All evaluation metrics can be initialized with the following optional arguments:
99
+ ```python
100
+ def __init__(self, metric_name):
101
+ ```
102
+
103
+ All evaluation metric objects implement the following methods:
104
+ ```python
105
+ def evaluate(self, model, data):
106
+
107
+ def get_dict(self, keys):
108
+ ```
109
+
110
+
111
+ ## Datasets
112
+ Import and initialization:
113
+ ```python
114
+ import dataset.stdatasets as st_data
115
+ ```
116
+
117
+ ## Contributors
118
+ This repository is built by the [LILY Lab](https://yale-lily.github.io/) at Yale University, led by Prof. [Dragomir Radev](https://cpsc.yale.edu/people/dragomir-radev). The main contributors are [Ansong Ni](https://niansong1996.github.io), Zhangir Azerbayev, Troy Feng, Murori Mutuma and Yusen Zhang (Penn State). For comments and questions, please open an issue.
119
+
120
+ Platform: UNKNOWN
121
+ Classifier: Programming Language :: Python :: 3
122
+ Classifier: License :: OSI Approved :: MIT License
123
+ Classifier: Operating System :: OS Independent
124
+ Description-Content-Type: text/markdown
SummerTime.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ setup.py
3
+ summertime.py
4
+ SummerTime.egg-info/PKG-INFO
5
+ SummerTime.egg-info/SOURCES.txt
6
+ SummerTime.egg-info/dependency_links.txt
7
+ SummerTime.egg-info/top_level.txt
8
+ dataset/__init__.py
9
+ dataset/datasets_demo.py
10
+ dataset/huggingface_datasets.py
11
+ dataset/non_huggingface_datasets.py
12
+ dataset/st_dataset.py
13
+ evaluation/__init__.py
14
+ evaluation/base_metric.py
15
+ evaluation/bertscore_metric.py
16
+ evaluation/bleu_metric.py
17
+ evaluation/meteor_metric.py
18
+ evaluation/rouge_metric.py
19
+ evaluation/rougewe_metric.py
20
+ evaluation/summeval_metric.py
21
+ model/__init__.py
22
+ model/base_model.py
23
+ model/defaults.py
24
+ model/dialogue/__init__.py
25
+ model/dialogue/hmnet_model.py
26
+ model/multi_doc/__init__.py
27
+ model/multi_doc/base_multi_doc_model.py
28
+ model/multi_doc/multi_doc_joint_model.py
29
+ model/multi_doc/multi_doc_separate_model.py
30
+ model/query_based/__init__.py
31
+ model/query_based/base_query_based_model.py
32
+ model/query_based/bm25_model.py
33
+ model/query_based/tf_idf_model.py
34
+ model/single_doc/__init__.py
35
+ model/single_doc/bart_model.py
36
+ model/single_doc/base_single_doc_model.py
37
+ model/single_doc/lexrank_model.py
38
+ model/single_doc/longformer_model.py
39
+ model/single_doc/pegasus_model.py
40
+ model/single_doc/textrank_model.py
41
+ tests/__init__.py
42
+ tests/dataset_test.py
43
+ tests/demo_test.py
44
+ tests/evaluation_test.py
45
+ tests/integration_test.py
46
+ tests/model_test.py
SummerTime.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
1
+
SummerTime.egg-info/top_level.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ dataset
2
+ evaluation
3
+ model
4
+ tests
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ import SummerTime.model
2
+ import SummerTime.dataset.st_dataset as data
3
+ import SummerTime.evaluation
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import model as st_model
3
+ import gradio as gr
4
+
5
+
6
+ model = st_model.summarizer()
7
+
8
+ def inference(text):
9
+ documents = [text]
10
+ model.summarize(documents)
11
+ return model.summarize(documents)[0]
12
+
13
+ title = "SummerTime: Text Summarization for Non-Experts"
14
+ description = "This is a demo of SummerTime: An open-source text summarization toolkit for non-experts. You can read more about the project at the links below. Input your text below (or click one of the examples to load them), and the model will generate a summary for it."
15
+ article = "<p style='text-align: center'><a target='_blank' href='https://arxiv.org/abs/2108.12738'>SummerTime: Text Summarization Toolkit for Non-experts</a> | <a target='_blank' href='https://github.com/Yale-LILY/SummerTime'>Github Repo</a> | <a target='_blank' href='https://colab.research.google.com/drive/19tPdBgaJ4_QjSiFyoxtpnFGW4OG1gTec?usp=sharing'>Colab Notebook</a></p>"
16
+
17
+ gr.Interface(
18
+ inference,
19
+ [gr.inputs.Textbox(label="Input", lines=20)],
20
+ gr.outputs.Textbox(label="Output"),
21
+ title=title,
22
+ description=description,
23
+ article=article,
24
+ examples=[["""PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions.
25
+ The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected
26
+ by the shutoffs which were expected to last through at least midday tomorrow."""],
27
+ ["""Representative Kevin McCarthy, the House Republican leader, has threatened to retaliate against any company that complies with the congressional committee investigating the Jan. 6 riot, after the panel asked dozens of firms to preserve the phone and social media records of 11 far-right members of Congress who pushed to overturn the results of the 2020 election. Mr. McCarthy’s warning was an escalation of his efforts to thwart a full accounting of the deadly attack at the Capitol carried out by a pro-Trump mob, and his latest attempt to insulate the former president and Republican lawmakers from scrutiny of any ties to the violence. It came after he led the G.O.P. opposition to the creation of an independent bipartisan commission to investigate the riot, and then pulled five Republican congressmen from the select committee that Democrats created on their own, boycotting the proceedings."""],
28
+ ["""Asked about the report, Google responded in an email that its "advertising technologies help websites and apps fund their content, enable small businesses to grow, and protect users from exploitative privacy practices and bad ad experiences." A lawsuit by 38 U.S. states and territories accuses Google of abusing its market power in an effort to make its search engine as dominant inside cars, TVs and speakers as it is in phones. This was consolidated with the federal lawsuit for purposes of discovery. Texas, backed by other states, filed a separate lawsuit against Google, accusing it of breaking antitrust law in how it runs its online advertising business."""]]).launch(debug=True)
build/scripts-3.9/summertime ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ #!python
2
+
3
+ print("welcome to Summer Time!")
dataset/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataset.dataset_loaders import (
2
+ CnndmDataset,
3
+ MultinewsDataset,
4
+ SamsumDataset,
5
+ XsumDataset,
6
+ PubmedqaDataset,
7
+ MlsumDataset,
8
+ ScisummnetDataset,
9
+ SummscreenDataset,
10
+ QMsumDataset,
11
+ ArxivDataset,
12
+ )
13
+
14
+
15
+ SUPPORTED_SUMM_DATASETS = [
16
+ CnndmDataset,
17
+ MultinewsDataset,
18
+ SamsumDataset,
19
+ XsumDataset,
20
+ PubmedqaDataset,
21
+ MlsumDataset,
22
+ ScisummnetDataset,
23
+ SummscreenDataset,
24
+ QMsumDataset,
25
+ ArxivDataset,
26
+ ]
27
+
28
+
29
+ def list_all_datasets():
30
+ all_datasets = []
31
+ for ds in SUPPORTED_SUMM_DATASETS:
32
+ dataset_description = ds.generate_basic_description()
33
+
34
+ all_datasets.append((ds.dataset_name, dataset_description))
35
+
36
+ return all_datasets
dataset/dataset_loaders.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path
2
+ from tqdm import tqdm
3
+ from typing import List, Generator, Optional, Union
4
+
5
+ from datasets import Dataset
6
+
7
+ from dataset.st_dataset import SummInstance, SummDataset
8
+
9
+
10
+ # Set directory to load non_huggingface dataset scripts
11
+ FILE_DIRECTORY_PATH = path.dirname(path.realpath(__file__))
12
+ BASE_NONHUGGINGFACE_DATASETS_PATH = path.join(
13
+ FILE_DIRECTORY_PATH, "non_huggingface_datasets_builders"
14
+ )
15
+
16
+
17
+ # Huggingface Datasets
18
+
19
+
20
+ class CnndmDataset(SummDataset):
21
+ """
22
+ The CNN/DM dataset
23
+ """
24
+
25
+ dataset_name = "CNN/DailyMail"
26
+
27
+ is_query_based = False
28
+ is_dialogue_based = False
29
+ is_multi_document = False
30
+
31
+ huggingface_dataset = True
32
+ huggingface_page = "https://huggingface.co/datasets/cnn_dailymail"
33
+
34
+ def __init__(self):
35
+ super().__init__(
36
+ dataset_args=(
37
+ "cnn_dailymail",
38
+ "3.0.0",
39
+ )
40
+ )
41
+
42
+ def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
43
+ """
44
+ Overrides the SummDataset '_process_data()' method
45
+ This method processes the data contained in the dataset
46
+ and puts each data instance into a SummInstance object
47
+ :param dataset: a train/validation/test dataset
48
+ :rtype: a generator yielding SummInstance objects
49
+ """
50
+ for instance in tqdm(data):
51
+ article: str = instance["article"]
52
+ highlights: str = instance["highlights"]
53
+ summ_instance = SummInstance(source=article, summary=highlights)
54
+
55
+ yield summ_instance
56
+
57
+
58
+ class MultinewsDataset(SummDataset):
59
+ """
60
+ The Multi News dataset
61
+ """
62
+
63
+ dataset_name = "Multinews"
64
+
65
+ is_query_based = False
66
+ is_dialogue_based = False
67
+ is_multi_document = True
68
+
69
+ huggingface_dataset = True
70
+ huggingface_page = "https://huggingface.co/datasets/multi_news"
71
+
72
+ def __init__(self):
73
+ super().__init__(dataset_args=("multi_news",))
74
+
75
+ def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
76
+ """
77
+ Overrides the SummDataset '_process_data()' method
78
+ This method processes the data contained in the dataset
79
+ and puts each data instance into a SummInstance object
80
+ :param dataset: a train/validation/test dataset
81
+ :rtype: a generator yielding SummInstance objects
82
+ """
83
+ for instance in tqdm(data):
84
+ document: list = [
85
+ doc for doc in instance["document"].split("|||||") if doc
86
+ ] # removes the empty string generated
87
+ # since each doc ends with the delimiting token '|||||'
88
+ # the final doc creates an empty string
89
+ summary: str = instance["summary"]
90
+ summ_instance = SummInstance(source=document, summary=summary)
91
+
92
+ yield summ_instance
93
+
94
+
95
+ class SamsumDataset(SummDataset):
96
+ """
97
+ The SAMsum Dataset
98
+ """
99
+
100
+ dataset_name = "Samsum"
101
+
102
+ is_query_based = False
103
+ is_dialogue_based = True
104
+ is_multi_document = False
105
+
106
+ huggingface_dataset = True
107
+ huggingface_page = "https://huggingface.co/datasets/samsum"
108
+
109
+ def __init__(self):
110
+ super().__init__(dataset_args=("samsum",))
111
+
112
+ def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
113
+ """
114
+ Overrides the SummDataset '_process_data()' method
115
+ This method processes the data contained in the dataset
116
+ and puts each data instance into a SummInstance object
117
+ :param dataset: a train/validation/test dataset
118
+ :rtype: a generator yielding SummInstance objects
119
+ """
120
+ for instance in tqdm(data):
121
+ dialogue: List = instance["dialogue"].split(
122
+ "\r\n"
123
+ ) # split each dialogue into a list of strings such as
124
+ # ["speaker1 : utter..", "speaker2 : utter..."]
125
+ summary: str = instance["summary"]
126
+ summ_instance = SummInstance(source=dialogue, summary=summary)
127
+
128
+ yield summ_instance
129
+
130
+
131
class XsumDataset(SummDataset):
    """
    The XSum dataset: BBC articles paired with single-sentence summaries.
    """

    dataset_name = "Xsum"

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/xsum"

    is_query_based = False
    is_dialogue_based = False
    is_multi_document = False

    def __init__(self):
        # Loaded straight from the huggingface hub under its default config.
        super().__init__(dataset_args=("xsum",))

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every raw record of a train/validation/test split into a SummInstance.

        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            yield SummInstance(
                source=record["document"], summary=record["summary"]
            )
162
+
163
+
164
class PubmedqaDataset(SummDataset):
    """
    The PubMedQA dataset: biomedical abstracts with questions and long answers,
    used here as a query-based summarization task.
    """

    dataset_name = "Pubmedqa"

    is_query_based = True
    is_dialogue_based = False
    is_multi_document = False

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/pubmed_qa"

    def __init__(self, seed=None):
        # NOTE(review): `seed` is accepted for interface symmetry but is not
        # forwarded anywhere — confirm whether it should reach the parent.
        super().__init__(dataset_args=("pubmed_qa", "pqa_artificial"))

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every raw record of a train/validation/test split into a SummInstance.

        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            # The contexts of one record are joined into a single source string;
            # the long answer acts as the summary, the question as the query.
            yield SummInstance(
                source=" ".join(record["context"]["contexts"]),
                summary=record["long_answer"],
                query=record["question"],
            )
201
+
202
+
203
class MlsumDataset(SummDataset):
    """
    The MLsum Dataset - A multi-lingual dataset featuring 5 languages
    Includes 1.5 million news articles and their corresponding summaries

    "de" - German
    "es" - Spanish
    "fr" - French
    "ru" - Russian
    "tu" - Turkish
    """

    dataset_name = "MlSum"

    is_query_based = False
    is_dialogue_based = False
    is_multi_document = False

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/mlsum"
    supported_languages = ["de", "es", "fr", "ru", "tu"]

    mlsum_instantiation_guide = """The languages supported for the Mlsum Dataset are:
        de - German
        es - Spanish
        fr - French
        ru - Russian
        tu - Turkish

        Examples to instantiate the dataset:
        1. Dataset with only one language
        dataset = MlsumDataset({language_token})
        dataset = MlsumDataset("es")
        dataset = MlsumDataset("tu")...

        2. Dataset with a multiple languages
        dataset = MlsumDataset({list of language_token})
        dataset = MlsumDataset(["es","de"])
        dataset = MlsumDataset(["es","de", "tu"])...

        3. Dataset with all supported languages (default)
        dataset = MlsumDataset(all)
        dataset = MlsumDataset()
        """

    def __init__(self, languages: Optional[Union[str, List[str]]] = "all"):
        """
        Create an MLsum dataset for one, several, or all supported languages.

        :param languages: a language code, a list of language codes, or "all" (default)
        """
        super().__init__(dataset_args=(languages,))

    def _load_dataset_safe(self, languages: Optional[Union[str, List[str]]]):
        """
        Overrides the parent class method.
        Loads one huggingface 'mlsum' subset per requested language and
        concatenates them into one combined dataset.

        :param languages: Optional, either a string or list of strings specifying
            the languages to load
        :rtype: DatasetDict containing the combined dataset
        :raises ValueError: if any requested language is not supported
        """
        print(MlsumDataset.mlsum_instantiation_guide)

        # Choose languages to download articles.
        # is_supported() raises ValueError on bad codes; it is called directly
        # instead of being wrapped in `assert`, so validation stays active
        # when Python runs with -O (asserts stripped).
        if languages == "all":
            selected_languages = MlsumDataset.supported_languages
        elif isinstance(languages, list):
            for language in languages:
                self.is_supported(language)
            selected_languages = languages
        else:
            self.is_supported(languages)
            selected_languages = [languages]

        # Concatenate selected languages into one dataset
        language_datasets = []
        for language in selected_languages:
            dataset = super()._load_dataset_safe(
                "mlsum",
                language,
            )

            language_datasets.append(dataset)

        mlsum_dataset = self._concatenate_dataset_dicts(language_datasets)

        return mlsum_dataset

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every raw record of a train/validation/test split into a SummInstance.

        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for instance in tqdm(data):
            # 'text' is one article per record in the mlsum schema, so the
            # source is a single string (the old `List` annotation was wrong).
            article: str = instance["text"]
            summary: str = instance["summary"]
            summ_instance = SummInstance(source=article, summary=summary)

            yield summ_instance

    def is_supported(self, language: str):
        """
        Checks whether the requested language is supported.

        :param language: string containing the requested language
        :rtype bool:
        :raises ValueError: if the language is not in supported_languages
        """
        if language not in MlsumDataset.supported_languages:
            print(MlsumDataset.mlsum_instantiation_guide)
            raise ValueError(
                f"The language(s): '{language}' entered is not supported. See above message for usage info"
            )
        else:
            return True
315
+
316
+
317
+ # Non-huggingface datasets
318
+
319
+
320
class ScisummnetDataset(SummDataset):
    """
    The SciSummNet dataset. Not hosted on huggingface, so it is downloaded and
    described by a custom builder script (see builder_script_path below).
    """

    dataset_name = "ScisummNet"

    version = "1.1.0"
    description = (
        "A summary of scientific papers should ideally incorporate the impact of the papers on the "
        "research community reflected by citations. To facilitate research in citation-aware scientific "
        "paper summarization (Scisumm), the CL-Scisumm shared task has been organized since 2014 for "
        "papers in the computational linguistics and NLP domain."
    )

    is_dialogue_based = False
    is_multi_document = False
    is_query_based = False

    huggingface_dataset = False
    # The custom builder script lives next to the other non-huggingface
    # builders and is named after the dataset ("scisummnet.py").
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self, seed=None):
        # NOTE(review): `seed` is accepted but never used or forwarded —
        # confirm whether it was meant to drive the val/test split generation.
        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        This method processes the data contained in the dataset
        and puts each data instance into a SummInstance object.

        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for instance in tqdm(data):
            # The source is two documents: the paper XML and the raw JSON text
            # of its annotated citing sentences.
            docs: List = [
                instance["document_xml"],
                instance["citing_sentences_annotated.json"],
            ]
            summary: str = instance["summary"]
            summ_instance = SummInstance(source=docs, summary=summary)

            yield summ_instance
365
+
366
+
367
class SummscreenDataset(SummDataset):
    """
    The SummScreen dataset. Not hosted on huggingface, so it is downloaded and
    described by a custom builder script (see builder_script_path below).
    """

    dataset_name = "Summscreen"

    version = "1.1.0"
    is_dialogue_based = True
    is_multi_document = False
    is_query_based = False

    huggingface_dataset = False
    # The builder script is named after the dataset ("summscreen.py").
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self, seed=None):
        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every raw record of a train/validation/test split into a SummInstance.

        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for episode in tqdm(data):
            # 'transcript' already arrives as a list of dialogue strings;
            # the recap serves as the reference summary.
            yield SummInstance(
                source=episode["transcript"], summary=episode["recap"]
            )
404
+
405
+
406
class QMsumDataset(SummDataset):
    """
    QMSum Dataset
    """

    dataset_name = "QMsum"
    description = """
    QMSum is a new human-annotated benchmark for query-based multi-domain meeting summarization task,
    which consists of 1,808 query-summary pairs over 232 meetings in multiple domains.
    """

    is_dialogue_based = True
    is_multi_document = False
    is_query_based = True

    huggingface_dataset = False
    # The builder script is named after the dataset ("qmsum.py").
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self):
        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Each meeting yields one SummInstance per (general or specific) query,
        all sharing the same meeting transcript as source.

        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for instance in tqdm(data):
            # The transcript is identical for every query on the same meeting,
            # so build the "speaker : content" lines once per instance instead
            # of once per query (the original rebuilt it inside the loop).
            meeting: List = [
                utterance["speaker"] + " : " + utterance["content"]
                for utterance in instance["meeting_transcripts"]
            ]
            for query_set in (
                instance["general_query_list"] + instance["specific_query_list"]
            ):
                query: str = query_set["query"]
                summary: str = query_set["answer"]
                # Copy the list so each SummInstance owns its source, exactly
                # as when the list was rebuilt per query.
                summ_instance = SummInstance(
                    source=list(meeting), summary=summary, query=query
                )

                yield summ_instance
452
+
453
+
454
class ArxivDataset(SummDataset):
    """
    The Arxiv Dataset
    """

    dataset_name = "Arxiv_longsummarization"
    description = """
    A summarization dataset comprised of pairs of scientific papers.
    The dataset provides a challenging testbed for abstractive summarization.
    It contains papers and their abstracts.
    """

    is_dialogue_based = False
    is_multi_document = False
    is_query_based = False

    huggingface_dataset = False
    # The builder script is named after the dataset ("arxiv_longsummarization.py").
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self):
        # Warn before kicking off the (very large) download in super().__init__().
        warning_lines = (
            "*****************",
            "***Attention***",
            "This dataset is quite large (approx 5Gb and will need about 15 Gb for the extraction process",
            "Cancel/interrupt the download if size and time constraints will not be met",
            "*****************",
        )
        print("\n".join(warning_lines))

        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every raw record of a train/validation/test split into a SummInstance.

        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for paper in tqdm(data):
            # The abstract arrives as a list of sentences; joined into one string.
            abstract = " ".join(paper["abstract_text"])
            yield SummInstance(source=paper["article_text"], summary=abstract)
dataset/non_huggingface_datasets_builders/arxiv_longsummarization.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import datasets
4
+
5
+
6
+ """Arxiv dataset."""
7
+
8
+
9
+ _CITATION = """
10
+ @article{Cohan_2018,
11
+ title={A Discourse-Aware Attention Model for Abstractive Summarization of
12
+ Long Documents},
13
+ url={http://dx.doi.org/10.18653/v1/n18-2097},
14
+ DOI={10.18653/v1/n18-2097},
15
+ journal={Proceedings of the 2018 Conference of the North American Chapter of
16
+ the Association for Computational Linguistics: Human Language
17
+ Technologies, Volume 2 (Short Papers)},
18
+ publisher={Association for Computational Linguistics},
19
+ author={Cohan, Arman and Dernoncourt, Franck and Kim, Doo Soon and Bui, Trung and Kim, Seokhwan and Chang, Walter and Goharian, Nazli},
20
+ year={2018}
21
+ }
22
+ """
23
+
24
+ _DESCRIPTION = """
25
+ A summarization dataset comprised of pairs of scientific papers.
26
+ The dataset provides a challenging testbed for abstractive summarization.
27
+ It contains papers and their abstracts.
28
+ """
29
+
30
+ _HOMEPAGE = "https://github.com/armancohan/long-summarization"
31
+
32
+ _LICENSE = "Apache-2.0 License"
33
+
34
+ _URL = "https://archive.org/download/armancohan-long-summarization-paper-code/arxiv-dataset.zip"
35
+
36
+
37
class SummertimeArxiv(datasets.GeneratorBasedBuilder):
    """Arxiv long summarization dataset: huggingface builder for the
    armancohan/long-summarization arXiv release (train/val/test jsonl files)."""

    VERSION = datasets.Version("1.0.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        # One example per paper: its id, body sentences, and abstract sentences
        # (both text fields are lists of sentence strings).
        features = datasets.Features(
            {
                "article_id": datasets.Value("string"),
                "article_text": [datasets.Value("string")],
                "abstract_text": [datasets.Value("string")],
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        my_urls = _URL
        # The zip extracts to an 'arxiv-dataset' directory holding one
        # jsonl file per split.
        path = dl_manager.download_and_extract(my_urls)
        path = os.path.join(path, "arxiv-dataset")

        trainpath = os.path.join(path, "train.txt")
        valpath = os.path.join(path, "val.txt")
        testpath = os.path.join(path, "test.txt")

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": trainpath, "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": valpath, "split": "val"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": testpath, "split": "test"},
            ),
        ]

    def _generate_examples(self, filepath, split):
        """Yields examples.

        Each line of the split file is one JSON object describing a paper;
        the article_id doubles as the example key.
        """

        with open(filepath, "r") as f:
            for line in f:

                instance = json.loads(line)

                entry = {}
                entry["article_id"] = instance["article_id"]
                entry["article_text"] = instance["article_text"]
                entry["abstract_text"] = instance["abstract_text"]

                yield entry["article_id"], entry
dataset/non_huggingface_datasets_builders/qmsum.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import datasets
4
+
5
+
6
+ """QMsum dataset."""
7
+
8
+
9
+ _CITATION = """
10
+ @inproceedings{zhong2021qmsum,
11
+ title={{QMS}um: {A} {N}ew {B}enchmark for {Q}uery-based {M}ulti-domain {M}eeting {S}ummarization},
12
+ author={Zhong, Ming and Yin, Da and Yu, Tao and Zaidi, Ahmad and Mutuma, Mutethia and Jha, Rahul and Hassan Awadallah, Ahmed and Celikyilmaz, Asli and Liu, Yang and Qiu, Xipeng and Radev, Dragomir},
13
+ booktitle={North American Association for Computational Linguistics (NAACL)},
14
+ year={2021}
15
+ }
16
+ """
17
+
18
+ _DESCRIPTION = """
19
+ QMSum is a new human-annotated benchmark for query-based multi-domain meeting summarization task, \
20
+ which consists of 1,808 query-summary pairs over 232 meetings in multiple domains.
21
+ """
22
+
23
+ _HOMEPAGE = "https://github.com/Yale-LILY/QMSum"
24
+
25
+ _BASE_URL = "https://raw.githubusercontent.com/Yale-LILY/QMSum/main/data/ALL/jsonl"
26
+ _URLs = {
27
+ "train": _BASE_URL + "/train.jsonl",
28
+ "val": _BASE_URL + "/val.jsonl",
29
+ "test": _BASE_URL + "/test.jsonl",
30
+ }
31
+
32
+
33
class SummertimeQmsum(datasets.GeneratorBasedBuilder):
    """QMsum dataset: huggingface builder for the Yale-LILY/QMSum jsonl
    release (train/val/test, one meeting per line)."""

    VERSION = datasets.Version("1.0.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        # A meeting consists of speaker-tagged transcript turns plus two lists
        # of query/answer pairs; specific queries also carry text-span info.
        features = datasets.Features(
            {
                "entry_number": datasets.Value("string"),
                "meeting_transcripts": [
                    {
                        "speaker": datasets.Value("string"),
                        "content": datasets.Value("string"),
                    }
                ],
                "general_query_list": [
                    {
                        "query": datasets.Value("string"),
                        "answer": datasets.Value("string"),
                    }
                ],
                "specific_query_list": [
                    {
                        "query": datasets.Value("string"),
                        "answer": datasets.Value("string"),
                        "relevant_text_span": [[datasets.Value("string")]],
                    }
                ],
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=None,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        my_urls = _URLs
        downloaded_files = dl_manager.download_and_extract(my_urls)

        trainpath = downloaded_files["train"]
        valpath = downloaded_files["val"]
        testpath = downloaded_files["test"]

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": trainpath, "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": valpath, "split": "val"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": testpath, "split": "test"},
            ),
        ]

    def _generate_examples(self, filepath, split):
        """Yields examples.

        Each line of a split's jsonl file is one complete meeting record.
        Keys are synthesized as "<split>_<line index>".
        """

        # Open explicitly as UTF-8 so decoding does not depend on the
        # platform locale (the previous os.path.join(filepath) no-op was
        # removed: joining a single path returns it unchanged).
        with open(filepath, encoding="utf-8") as f:
            for i, line in enumerate(f):

                instance = json.loads(line)

                entry = {
                    "entry_number": split + "_" + str(i),
                    "meeting_transcripts": instance["meeting_transcripts"],
                    "general_query_list": instance["general_query_list"],
                    "specific_query_list": instance["specific_query_list"],
                }

                yield entry["entry_number"], entry
dataset/non_huggingface_datasets_builders/scisummnet.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datasets
3
+
4
+
5
+ """Scisummnet dataset."""
6
+
7
+
8
+ _CITATION = """
9
+ @InProceedings{yasunaga&al.19.scisumm,
10
+ title = {{ScisummNet}: A Large Annotated Corpus and Content-Impact Models for Scientific Paper Summarization with Citation Networks},
11
+ author = {Michihiro Yasunaga and Jungo Kasai and Rui Zhang and Alexander Fabbri and Irene Li and Dan Friedman and Dragomir Radev},
12
+ booktitle = {Proceedings of AAAI 2019},
13
+ year = {2019}
14
+ }
15
+ @InProceedings{yasunaga&al.17,
16
+ title = {Graph-based Neural Multi-Document Summarization},
17
+ author = {Yasunaga, Michihiro and Zhang, Rui and Meelu, Kshitijh and Pareek, Ayush and Srinivasan, Krishnan and Radev, Dragomir R.},
18
+ booktitle = {Proceedings of CoNLL 2017},
19
+ year = {2017}
20
+ }
21
+ """
22
+
23
+ _DESCRIPTION = """
24
+ A summary of scientific papers should ideally incorporate the impact of the papers on the research community
25
+ reflected by citations. To facilitate research in citation-aware scientific paper summarization (Scisumm),
26
+ the CL-Scisumm shared task has been organized since 2014 for papers in the computational linguistics and NLP domain.
27
+ """
28
+
29
+ _HOMEPAGE = "https://cs.stanford.edu/~myasu/projects/scisumm_net/"
30
+
31
+ _LICENSE = "CC BY-SA 4.0"
32
+
33
+ _URLs = "https://cs.stanford.edu/~myasu/projects/scisumm_net/scisummnet_release1.1__20190413.zip"
34
+
35
+
36
class SummertimeScisummnet(datasets.GeneratorBasedBuilder):
    """Scisummnet dataset: huggingface builder for the ScisummNet 1.1
    release (one folder per paper; only a train split is provided)."""

    VERSION = datasets.Version("1.1.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        # Raw file contents are exposed as strings: the paper XML, the
        # citing-sentence annotations (raw JSON text), and the gold summary.
        features = datasets.Features(
            {
                "entry_number": datasets.Value("string"),
                "document_xml": datasets.Value("string"),
                "citing_sentences_annotated.json": datasets.Value("string"),
                "summary": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        my_urls = _URLs
        path = dl_manager.download_and_extract(my_urls)
        # The release contains exactly one usable tree; everything is
        # exposed as the TRAIN split (no val/test in the source data).
        trainpath = os.path.join(
            path, "scisummnet_release1.1__20190413", "top1000_complete"
        )
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"extraction_path": trainpath, "split": "train"},
            )
        ]

    def _generate_examples(self, extraction_path, split):
        """Yields examples.

        Each sub-folder is one paper, keyed by its folder name, holding
        Documents_xml/<id>.xml, citing_sentences_annotated.json, and
        summary/<id>.gold.txt.
        """

        for folder in os.listdir(extraction_path):

            entry = {}

            entry["entry_number"] = folder

            doc_xml_path = os.path.join(
                extraction_path, folder, "Documents_xml", folder + ".xml"
            )
            with open(doc_xml_path, "r", encoding="utf-8") as f:
                entry["document_xml"] = f.read()

            cite_annot_path = os.path.join(
                extraction_path, folder, "citing_sentences_annotated.json"
            )
            with open(cite_annot_path, "r", encoding="utf-8") as f:
                entry["citing_sentences_annotated.json"] = f.read()

            summary_path = os.path.join(
                extraction_path, folder, "summary", folder + ".gold.txt"
            )
            with open(summary_path, "r", encoding="utf-8") as f:
                entry["summary"] = f.read()

            yield entry["entry_number"], entry
dataset/non_huggingface_datasets_builders/summscreen.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import datasets
4
+
5
+
6
+ """Summscreen dataset."""
7
+
8
+
9
+ _CITATION = """
10
+ @article{DBLP:journals/corr/abs-2104-07091,
11
+ author = {Mingda Chen and
12
+ Zewei Chu and
13
+ Sam Wiseman and
14
+ Kevin Gimpel},
15
+ title = {SummScreen: {A} Dataset for Abstractive Screenplay Summarization},
16
+ journal = {CoRR},
17
+ volume = {abs/2104.07091},
18
+ year = {2021},
19
+ url = {https://arxiv.org/abs/2104.07091},
20
+ archivePrefix = {arXiv},
21
+ eprint = {2104.07091},
22
+ timestamp = {Mon, 19 Apr 2021 16:45:47 +0200},
23
+ biburl = {https://dblp.org/rec/journals/corr/abs-2104-07091.bib},
24
+ bibsource = {dblp computer science bibliography, https://dblp.org}
25
+ }
26
+ """
27
+
28
# Fixed: the description below was copy-pasted from the Scisummnet builder
# and described the wrong dataset; it now matches SummScreen (see _CITATION).
_DESCRIPTION = """
SummScreen is an abstractive summarization dataset combining TV series transcripts
and episode recaps. Each example pairs a transcript of an episode (from
ForeverDreaming or TVMegaSite) with its community-written recap.
"""

_HOMEPAGE = "https://github.com/mingdachen/SummScreen"

# Fixed typo: "Licencse" -> "License"
_LICENSE = "MIT License"

_URLs = "https://drive.google.com/uc?id=1BvdIllGBo9d2-bzXQRzWuJXB04XPVmfF"
39
+
40
+
41
class SummertimeSummscreen(datasets.GeneratorBasedBuilder):
    """Summscreen dataset: huggingface builder merging the ForeverDreaming
    and TVMegaSite halves of SummScreen into single train/dev/test splits."""

    VERSION = datasets.Version("1.1.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        # transcript is a sequence of dialogue lines; recap is one string.
        features = datasets.Features(
            {
                "entry_number": datasets.Value("string"),
                "transcript": datasets.features.Sequence(datasets.Value("string")),
                "recap": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        my_urls = _URLs
        path = dl_manager.download_and_extract(my_urls)
        path = os.path.join(path, "SummScreen")

        # Each split combines the ForeverDreaming (fd_*) and TVMegaSite
        # (tms_*) files; paths here are relative to the extracted root.
        trainpath_fd = os.path.join("ForeverDreaming", "fd_train.json")
        trainpath_tms = os.path.join("TVMegaSite", "tms_train.json")
        trainpaths = [trainpath_fd, trainpath_tms]

        devpath_fd = os.path.join("ForeverDreaming", "fd_dev.json")
        devpath_tms = os.path.join("TVMegaSite", "tms_dev.json")
        devpaths = [devpath_fd, devpath_tms]

        testpath_fd = os.path.join("ForeverDreaming", "fd_test.json")
        testpath_tms = os.path.join("TVMegaSite", "tms_test.json")
        testpaths = [testpath_fd, testpath_tms]

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepaths": (path, trainpaths), "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepaths": (path, devpaths), "split": "dev"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepaths": (path, testpaths), "split": "test"},
            ),
        ]

    def _generate_examples(self, filepaths, split):
        """Yields examples.

        :param filepaths: a (root path, list of relative jsonl paths) tuple;
            each line of each file is one episode record keyed by 'filename'.
        """

        path, relative_filepaths = filepaths
        for filepath in relative_filepaths:

            extraction_path = os.path.join(path, filepath)

            with open(extraction_path, "r") as f:
                for line in f:
                    # Strip the "@@ " subword-split markers left in the raw data.
                    processed_line = line.replace("@@ ", "")
                    instance = json.loads(processed_line)

                    entry = {}
                    entry["entry_number"] = instance["filename"]
                    entry["transcript"] = instance["Transcript"]
                    entry["recap"] = instance["Recap"][
                        0
                    ]  # Recap is a single string in list
                    yield entry["entry_number"], entry
dataset/st_dataset.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from pprint import pformat
3
+ from time import sleep
4
+ from typing import List, Tuple, Optional, Union, Generator
5
+
6
+ from datasets import (
7
+ Dataset,
8
+ DatasetDict,
9
+ DatasetInfo,
10
+ concatenate_datasets,
11
+ load_dataset,
12
+ )
13
+
14
+ # Defualt values for retrying dataset download
15
+ DEFAULT_NUMBER_OF_RETRIES_ALLOWED = 5
16
+ DEFAULT_WAIT_SECONDS_BEFORE_RETRY = 5
17
+
18
+ # Default value for creating missing val/test splits
19
+ TEST_OR_VAL_SPLIT_RATIO = 0.1
20
+
21
+
22
class SummInstance:
    """
    A single summarization example: a source document (either one string or a
    list of string segments), its reference summary, and an optional query for
    query-based summarization tasks.
    """

    def __init__(
        self, source: Union[List[str], str], summary: str, query: Optional[str] = None
    ):
        """
        Create a summarization instance.

        :rtype: object
        :param source: either `List[str]` or `str`, depending on the dataset;
            string joining may be needed to fit specific models
        :param summary: a string summary that serves as ground truth
        :param query: Optional, applies when a string query is present
        """
        self.source = source
        self.summary = summary
        self.query = query

    def _fields(self) -> dict:
        # Collect the displayable fields; the query key is present only when
        # a truthy query was supplied, mirroring the display contract.
        shown = {"source": self.source, "summary": self.summary}
        if self.query:
            shown["query"] = self.query
        return shown

    def __repr__(self):
        return str(self._fields())

    def __str__(self):
        return pformat(self._fields(), indent=1)
56
+
57
+
58
+ class SummDataset:
59
+ """
60
+ Dataset class for summarization, which takes into account of the following tasks:
61
+ * Single document summarization
62
+ * Multi-document/Dialogue summarization
63
+ * Query-based summarization
64
+ """
65
+
66
+ def __init__(
67
+ self, dataset_args: Optional[Tuple[str]] = None, splitseed: Optional[int] = None
68
+ ):
69
+ """Create dataset information from the huggingface Dataset class
70
+ :rtype: object
71
+ :param dataset_args: a tuple containing arguments to passed on to the 'load_dataset_safe' method.
72
+ Only required for datasets loaded from the Huggingface library.
73
+ The arguments for each dataset are different and comprise of a string or multiple strings
74
+ :param splitseed: a number to instantiate the random generator used to generate val/test splits
75
+ for the datasets without them
76
+ """
77
+
78
+ # Load dataset from huggingface, use default huggingface arguments
79
+ if self.huggingface_dataset:
80
+ dataset = self._load_dataset_safe(*dataset_args)
81
+ # Load non-huggingface dataset, use custom dataset builder
82
+ else:
83
+ dataset = self._load_dataset_safe(path=self.builder_script_path)
84
+
85
+ info_set = self._get_dataset_info(dataset)
86
+
87
+ # Ensure any dataset with a val or dev or validation split is standardised to validation split
88
+ if "val" in dataset:
89
+ dataset["validation"] = dataset["val"]
90
+ dataset.remove("val")
91
+ elif "dev" in dataset:
92
+ dataset["validation"] = dataset["dev"]
93
+ dataset.remove("dev")
94
+
95
+ # If no splits other other than training, generate them
96
+ assert (
97
+ "train" in dataset or "validation" in dataset or "test" in dataset
98
+ ), "At least one of train/validation test needs to be not empty!"
99
+
100
+ if not ("validation" in dataset or "test" in dataset):
101
+ dataset = self._generate_missing_val_test_splits(dataset, splitseed)
102
+
103
+ self.description = info_set.description
104
+ self.citation = info_set.citation
105
+ self.homepage = info_set.homepage
106
+
107
+ # Extract the dataset entries from folders and load into dataset
108
+ self._train_set = self._process_data(dataset["train"])
109
+ self._validation_set = self._process_data(
110
+ dataset["validation"]
111
+ ) # Some datasets have a validation split
112
+ self._test_set = self._process_data(dataset["test"])
113
+
114
@property
def train_set(self) -> "Union[Generator[SummInstance, None, None], List]":
    """Return the training split, or an empty list (with a notice) when absent."""
    if self._train_set is None:
        print(
            f"{self.dataset_name} does not contain a train set, empty list returned"
        )
        return list()
    return self._train_set
123
+
124
@property
def validation_set(self) -> "Union[Generator[SummInstance, None, None], List]":
    """Return the validation split, or an empty list (with a notice) when absent."""
    if self._validation_set is None:
        print(
            f"{self.dataset_name} does not contain a validation set, empty list returned"
        )
        return list()
    return self._validation_set
133
+
134
@property
def test_set(self) -> "Union[Generator[SummInstance, None, None], List]":
    """Return the test split, or an empty list (with a notice) when absent."""
    if self._test_set is None:
        print(
            f"{self.dataset_name} does not contain a test set, empty list returned"
        )
        return list()
    return self._test_set
143
+
144
def _load_dataset_safe(self, *args, **kwargs) -> "Dataset":
    """
    Robust wrapper around huggingface's 'load_dataset()'.

    The original 'load_dataset()' occasionally fails when it cannot reach a
    server, especially after multiple requests, so this method retries the
    download several times, waiting a fixed interval between attempts.
    All arguments are passed through to 'load_dataset' unaltered.

    :param args: non-keyword arguments passed on to the 'load_dataset' function
    :param kwargs: keyword arguments passed on to the 'load_dataset' function
    :rtype: Dataset
    :raises RuntimeError: when every retry fails with a ConnectionError
    """
    tries = DEFAULT_NUMBER_OF_RETRIES_ALLOWED
    wait_time = DEFAULT_WAIT_SECONDS_BEFORE_RETRY

    for attempt in range(tries):
        try:
            return load_dataset(*args, **kwargs)
        except ConnectionError as err:
            if attempt < tries - 1:  # attempt is zero indexed
                sleep(wait_time)
            else:
                # Chain the original error so the root cause stays visible.
                raise RuntimeError(
                    "Wait for a minute and attempt downloading the dataset again. "
                    "The server hosting the dataset occasionally times out."
                ) from err
174
+
175
+ def _get_dataset_info(self, data_dict: DatasetDict) -> DatasetInfo:
176
+ """
177
+ Get the information set from the dataset
178
+ The information set contains: dataset name, description, version, citation and licence
179
+ :param data_dict: DatasetDict
180
+ :rtype: DatasetInfo
181
+ """
182
+ return data_dict["train"].info
183
+
184
+ @abstractmethod
185
+ def _process_data(self, dataset: Dataset) -> Generator[SummInstance, None, None]:
186
+ """
187
+ Abstract class method to process the data contained within each dataset.
188
+ Each dataset class processes it's own information differently due to the diversity in domains
189
+ This method processes the data contained in the dataset
190
+ and puts each data instance into a SummInstance object,
191
+ the SummInstance has the following properties [source, summary, query[optional]]
192
+ :param dataset: a train/validation/test dataset
193
+ :rtype: a generator yielding SummInstance objects
194
+ """
195
+ return
196
+
197
+ def _generate_missing_val_test_splits(
198
+ self, dataset_dict: DatasetDict, seed: int
199
+ ) -> DatasetDict:
200
+ """
201
+ Creating the train, val and test splits from a dataset
202
+ the generated sets are 'train: ~.80', 'validation: ~.10', and 'test: ~10' in size
203
+ the splits are randomized for each object unless a seed is provided for the random generator
204
+
205
+ :param dataset: Arrow Dataset with containing, usually the train set
206
+ :param seed: seed for the random generator to shuffle the dataset
207
+ :rtype: Arrow DatasetDict containing the three splits
208
+ """
209
+
210
+ # Return dataset if no train set available for splitting
211
+ if "train" not in dataset_dict:
212
+ if "validation" not in dataset_dict:
213
+ dataset_dict["validation"] = None
214
+ if "test" not in dataset_dict:
215
+ dataset_dict["test"] = None
216
+
217
+ return dataset_dict
218
+
219
+ # Create a 'test' split from 'train' if no 'test' set is available
220
+ if "test" not in dataset_dict:
221
+ dataset_traintest_split = dataset_dict["train"].train_test_split(
222
+ test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
223
+ )
224
+ dataset_dict["train"] = dataset_traintest_split["train"]
225
+ dataset_dict["test"] = dataset_traintest_split["test"]
226
+
227
+ # Create a 'validation' split from the remaining 'train' set if no 'validation' set is available
228
+ if "validation" not in dataset_dict:
229
+ dataset_trainval_split = dataset_dict["train"].train_test_split(
230
+ test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
231
+ )
232
+ dataset_dict["train"] = dataset_trainval_split["train"]
233
+ dataset_dict["validation"] = dataset_trainval_split["test"]
234
+
235
+ return dataset_dict
236
+
237
+ def _concatenate_dataset_dicts(
238
+ self, dataset_dicts: List[DatasetDict]
239
+ ) -> DatasetDict:
240
+ """
241
+ Concatenate two dataset dicts with similar splits and columns tinto one
242
+ :param dataset_dicts: A list of DatasetDicts
243
+ :rtype: DatasetDict containing the combined data
244
+ """
245
+
246
+ # Ensure all dataset dicts have the same splits
247
+ setsofsplits = set(tuple(dataset_dict.keys()) for dataset_dict in dataset_dicts)
248
+ if len(setsofsplits) > 1:
249
+ raise ValueError("Splits must match for all datasets")
250
+
251
+ # Concatenate all datasets into one according to the splits
252
+ temp_dict = {}
253
+ for split in setsofsplits.pop():
254
+ split_set = [dataset_dict[split] for dataset_dict in dataset_dicts]
255
+ temp_dict[split] = concatenate_datasets(split_set)
256
+
257
+ return DatasetDict(temp_dict)
258
+
259
@classmethod
def generate_basic_description(cls) -> str:
    """
    Build a one-line dataset description from the class capability flags.

    :param cls: class object
    :rtype: string containing the description
    """
    query_part = "query-based " if cls.is_query_based else ""
    dialogue_part = "dialogue " if cls.is_dialogue_based else ""
    doc_part = "multi-document" if cls.is_multi_document else "single-document"
    return (
        f": {cls.dataset_name} is a {query_part}{dialogue_part}{doc_part} "
        "summarization dataset."
    )
276
+
277
def show_description(self):
    """Print the dataset's name followed by its long description."""
    header, body = self.dataset_name, self.description
    print(header, ":\n", body)
dependencies.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Migrate information to documentation/pypi for first release.
2
+
3
+ Dependencies:
4
+ - lexrank
5
+ - sentencepiece
6
+ - torch
7
+ - transformers
8
+
9
+ # datasets
10
+ - datasets
11
+ - py7zr
dist/SummerTime-0.1-py3-none-any.whl ADDED
Binary file (1.42 kB). View file
download.py ADDED
@@ -0,0 +1,3 @@
 
 
 
1
# One-off setup script: fetch NLTK corpora required at runtime.
import nltk

# "stopwords" is needed by the evaluation metrics (e.g. RougeWe calls
# nltk.download("stopwords") as well); fetching it ahead of time avoids a
# download at first use.
nltk.download("stopwords")
evaluation/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import site
import os

# summ_eval's ROUGE backend locates its bundled ROUGE-1.5.5 scripts via the
# ROUGE_HOME environment variable, so it must be set before the metric
# modules below are imported ("needed so that rouge works").
# NOTE(review): site.getsitepackages()[0] assumes summ_eval is installed in
# the first site-packages directory — confirm this holds for virtualenv /
# user-site installs.
package_path = site.getsitepackages()[0]
os.environ["ROUGE_HOME"] = package_path + "/summ_eval/ROUGE-1.5.5/"

from .rouge_metric import Rouge
from .bertscore_metric import BertScore
from .rougewe_metric import RougeWe
from .bleu_metric import Bleu
from .meteor_metric import Meteor

# Registry of all evaluation metrics exposed by SummerTime.
SUPPORTED_EVALUATION_METRICS = [BertScore, Bleu, Rouge, RougeWe, Meteor]
evaluation/base_metric.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Dict
2
+
3
+
4
class SummMetric:
    """
    Base class for SummerTime evaluation metrics.

    Concrete metrics fill in the class-level metadata below and override
    `evaluate`.
    """

    # metadata every concrete metric must provide
    metric_name: str = None
    range: Tuple[float, float] = None
    higher_is_better: bool = None
    requires_heavy_compute: bool = None

    def evaluate(
        self,
        # TODO zhangir: integrate with dataset api
        inputs: List[str],
        targets: List[str],
        keys: List[str],
    ) -> Dict[str, float]:
        """
        Score `inputs` against `targets` and return the requested metrics.

        :input: A list of summaries.
        :target: A list of target summaries corresponding to each entry of input.
        :keys: Which metrics to return, e.g, ['rouge_1_f_score', 'rouge_2_f_score']
        :return: A dictionary with keys metrics and values scores.
        """
        raise NotImplementedError(
            "the base class for metrics shouldn't be instantiated!"
        )
evaluation/bertscore_metric.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from summ_eval.bert_score_metric import BertScoreMetric
2
+ from evaluation.summeval_metric import SummEvalMetric
3
+ from typing import List, Dict
4
+
5
+
6
class BertScore(SummEvalMetric):
    """BERTScore metric (F1 by default), backed by SummEval's implementation."""

    metric_name = "bert score"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = True

    def __init__(self):
        # Delegate all scoring work to SummEval's BertScoreMetric.
        super().__init__(BertScoreMetric())

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str] = ["bert_score_f1"]
    ) -> Dict[str, float]:
        # TODO zhangir: update when datasets api is merged
        return super().evaluate(inputs, targets, keys)
evaluation/bleu_metric.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from summ_eval.bleu_metric import BleuMetric
2
+ from evaluation.summeval_metric import SummEvalMetric
3
+ from typing import List, Dict
4
+
5
+
6
class Bleu(SummEvalMetric):
    """Corpus BLEU (0-100 scale), backed by SummEval's implementation."""

    metric_name = "bleu"
    range = (0, 100)
    higher_is_better = True
    requires_heavy_compute = False

    def __init__(self):
        # Delegate all scoring work to SummEval's BleuMetric.
        super().__init__(BleuMetric())

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str] = ["bleu"]
    ) -> Dict[str, float]:
        # TODO zhangir: potentially update when dataset api is merged.
        return super().evaluate(inputs, targets, keys)
evaluation/meteor_metric.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_metric import SummMetric
2
+ from typing import List, Dict
3
+ from nltk.translate import meteor_score as nltk_meteor
4
+ import nltk
5
+ import statistics
6
+
7
+
8
class Meteor(SummMetric):
    """METEOR via nltk's meteor_score, averaged over all (input, target) pairs."""

    metric_name = "meteor"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = False

    def __init__(self):
        # nltk's METEOR implementation requires the WordNet corpus.
        nltk.download("wordnet")

    def evaluate(
        self, inputs: List[str], targets: List[str], keys=["meteor"]
    ) -> Dict[str, float]:
        """Return {'meteor': mean score}; any key other than 'meteor' raises KeyError."""
        for key in keys:
            if key != "meteor":
                raise KeyError(key, "is not a valid key")

        corpus_score = statistics.mean(
            nltk_meteor.meteor_score([candidate], target)
            for candidate, target in zip(inputs, targets)
        )
        return {key: corpus_score for key in keys}
evaluation/rouge_metric.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from summ_eval.rouge_metric import RougeMetric
2
+ from evaluation.summeval_metric import SummEvalMetric
3
+ from typing import List, Dict
4
+
5
+
6
class Rouge(SummEvalMetric):
    """ROUGE (1/2/L F-scores by default), backed by SummEval's RougeMetric."""

    metric_name = "rouge"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = False

    def __init__(self):
        super().__init__(RougeMetric())

    def evaluate(
        self,
        inputs: List[str],
        targets: List[str],
        keys: List[str] = ["rouge_1_f_score", "rouge_2_f_score", "rouge_l_f_score"],
    ) -> Dict[str, float]:
        # RougeMetric nests its scores under a "rouge" sub-dict, so unpack it
        # here instead of going through SummEvalMetric.evaluate.
        rouge_scores = self.se_metric.evaluate_batch(inputs, targets)["rouge"]
        return {key: rouge_scores[key] for key in keys}
evaluation/rougewe_metric.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from evaluation.summeval_metric import SummEvalMetric
2
+ from typing import List, Dict
3
+
4
+ import nltk
5
+
6
+
7
class RougeWe(SummEvalMetric):
    """ROUGE-WE (word-embedding ROUGE), backed by SummEval's RougeWeMetric."""

    metric_name = "rougeWE"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = True

    def __init__(self):
        # NOTE(review): import kept local, presumably to defer the heavy
        # summ_eval module load until the metric is actually used — confirm.
        from summ_eval.rouge_we_metric import RougeWeMetric

        nltk.download("stopwords")
        super().__init__(RougeWeMetric())

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str] = ["rouge_we_3_f"]
    ) -> Dict[str, float]:
        # TODO zhangir: update when dataset api is merged.
        return super().evaluate(inputs, targets, keys)
evaluation/summeval_metric.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_metric import SummMetric
2
+ from summ_eval.metric import Metric as SEMetric
3
+ from typing import List, Dict
4
+
5
+
6
class SummEvalMetric(SummMetric):
    """
    Generic adapter for summarization metrics whose backend is SummEval.
    """

    def __init__(self, se_metric: "SEMetric"):
        # The wrapped SummEval metric instance that does the actual scoring.
        self.se_metric = se_metric

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str]
    ) -> Dict[str, float]:
        score_dict = self.se_metric.evaluate_batch(inputs, targets)
        # Keys the backend did not produce map to None rather than raising.
        return {key: score_dict.get(key) for key in keys}
model/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .single_doc import (
2
+ BartModel,
3
+ LexRankModel,
4
+ LongformerModel,
5
+ PegasusModel,
6
+ TextRankModel,
7
+ )
8
+ from .multi_doc import MultiDocJointModel, MultiDocSeparateModel
9
+ from .dialogue import HMNetModel
10
+ from .query_based import TFIDFSummModel, BM25SummModel
11
+ from .defaults import summarizer
12
+
13
# Registry of every summarization model exposed by SummerTime;
# list_all_models() below iterates over this list.
SUPPORTED_SUMM_MODELS = [
    BartModel,
    LexRankModel,
    LongformerModel,
    PegasusModel,
    TextRankModel,
    MultiDocJointModel,
    MultiDocSeparateModel,
    HMNetModel,
    TFIDFSummModel,
    BM25SummModel,
]
25
+
26
+
27
def list_all_models():
    """Return (model_class, basic_description) pairs for every supported model."""
    return [
        (model_class, model_class.generate_basic_description())
        for model_class in SUPPORTED_SUMM_MODELS
    ]
model/base_model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+
4
class SummModel:
    """
    Base model class for SummerTime.

    Concrete models override `summarize`, `assert_summ_input_type` and
    `show_capability`, and set the class-level capability flags below.
    """

    # static capability flags, overridden by subclasses
    model_name = "None"
    is_extractive = False
    is_neural = False
    is_query_based = False
    is_dialogue_based = False
    is_multi_document = False

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
    ):
        """
        :param trained_domain: domain the underlying checkpoint was trained on
        :param max_input_length: maximum number of input tokens accepted
        :param max_output_length: maximum number of tokens in a generated summary
        """
        self.trained_domain = trained_domain
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length

    def summarize(
        self, corpus: Union[List[str], List[List[str]]], queries: List[str] = None
    ) -> List[str]:
        """
        All summarization models should have this function

        :param corpus: each string in the list is a source document to be summarized; if the model is multi-document or
            dialogue summarization model, then each instance contains a list of documents/utterances
        :param queries: a list of queries if this is a query-based model
        :return: a list of generated summaries
        """
        raise NotImplementedError(
            "The base class for models shouldn't be instantiated!"
        )

    @classmethod
    def assert_summ_input_type(
        cls, corpus: Union[List[str], List[List[str]]], queries: Union[List[str], None]
    ):
        """
        Verifies that type of input corpus or queries for summarization align with the model type.
        """
        raise NotImplementedError(
            "The base class for models shouldn't be instantiated!"
        )

    @classmethod
    def show_capability(cls) -> None:
        """
        Use concise language to show the strength and weakness for each model. Try not to use NLP terminologies
        """
        raise NotImplementedError(
            "The base class for models shouldn't be instantiated!"
        )

    @classmethod
    def generate_basic_description(cls) -> str:
        """
        Automatically generate the basic description string based on the attributes.

        Fixes two formatting bugs in the previous f-string assembly: a missing
        space after "a" for query-based models ("is aquery-based ...") and
        missing/doubled spaces around the "It can handle ..." sentence.
        """
        extractive_abstractive = "extractive" if cls.is_extractive else "abstractive"
        neural = "neural" if cls.is_neural else "non-neural"

        basic_description = (
            f"{cls.model_name} is a "
            f"{'query-based ' if cls.is_query_based else ''}"
            f"{extractive_abstractive}, {neural} model for summarization."
        )

        # Append the data-type sentence only for multi-document / dialogue models.
        qualifiers = [
            name
            for name, enabled in (
                ("multi-document", cls.is_multi_document),
                ("dialogue", cls.is_dialogue_based),
            )
            if enabled
        ]
        if qualifiers:
            basic_description += (
                f" It can handle {' '.join(qualifiers)} textual data."
            )

        return basic_description
model/defaults.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .single_doc import PegasusModel
2
+
3
+
4
class summarizer(PegasusModel):
    """Default SummerTime summarizer: single-document Pegasus, CPU by default."""

    def __init__(self, device="cpu"):
        """
        :param device: torch device string used for inference (default "cpu")
        """
        super(summarizer, self).__init__(device)

    def show_capability(self):
        """Describe the default model's capabilities."""
        # Fixed typo in the user-facing message: "singe" -> "single".
        print("Pegasus is the default single-document summarization model.")
        super(summarizer, self).show_capability()
model/dialogue/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from .hmnet_model import HMNetModel
model/dialogue/hmnet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json ADDED
@@ -0,0 +1 @@
 
1
+ [{"source": {"dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/"}, "task": "meeting", "name": "ami"}]
model/dialogue/hmnet/ExampleRawData/meeting_summarization/role_dict_ext.json ADDED
@@ -0,0 +1 @@
 
1
+ {}
model/dialogue/hmnet/config/dialogue.conf ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##################
2
+ # Trainer settings
3
+ ##################
4
+
5
+ MODEL MeetingNet_Transformer
6
+ TASK HMNet
7
+ CRITERION MLECriterion
8
+
9
+ SEED 1033
10
+
11
+ MAX_NUM_EPOCHS 20
12
+ EVAL_PER_UPDATE_NUM 10
13
+ UPDATES_PER_EPOCH 20
14
+
15
+ # The actual learning rate will be multiplied with the number of GPUs
16
+ OPTIMIZER RAdam
17
+ START_LEARNING_RATE 1e-3
18
+ LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler
19
+ WARMUP_STEPS 16000
20
+ WARMUP_INIT_LR 1e-4
21
+ WARMUP_END_LR 1e-3
22
+
23
+ # The actual start learning rate equals START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP
24
+ # Model will be updated after every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples
25
+ GRADIENT_ACCUMULATE_STEP 5
26
+
27
+ GRAD_CLIPPING 2
28
+
29
+ ##################
30
+ # Task settings
31
+ ##################
32
+
33
+ # This is the relative path to the directory where this conf file locates
34
+ USE_REL_DATA_PATH
35
+ TRAIN_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json
36
+ DEV_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json
37
+ TEST_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json
38
+ ROLE_DICT_FILE ../ExampleRawData/meeting_summarization/role_dict_ext.json
39
+
40
+ MINI_BATCH 1
41
+ MAX_PADDING_RATIO 1
42
+ BATCH_READ_AHEAD 10
43
+ DOC_SHUFFLE_BUF_SIZE 10
44
+ SAMPLE_SHUFFLE_BUFFER_SIZE 10
45
+ BATCH_SHUFFLE_BUFFER_SIZE 10
46
+
47
+ MAX_TRANSCRIPT_WORD 8300
48
+ #MAX_SENT_LEN 30
49
+ MAX_SENT_LEN 12
50
+ # MAX_SENT_NUM 300
51
+ MAX_SENT_NUM 60
52
+
53
+ ##################
54
+ # Model settings
55
+ ##################
56
+
57
+ DROPOUT 0.1
58
+ VOCAB_DIM 512
59
+ ROLE_SIZE 32
60
+ ROLE_DIM 16
61
+ POS_DIM 16
62
+ ENT_DIM 16
63
+
64
+ USE_ROLE
65
+ USE_POSENT
66
+
67
+ USE_BOS_TOKEN
68
+ USE_EOS_TOKEN
69
+
70
+ TRANSFORMER_EMBED_DROPOUT 0.1
71
+ TRANSFORMER_RESIDUAL_DROPOUT 0.1
72
+ TRANSFORMER_ATTENTION_DROPOUT 0.1
73
+ TRANSFORMER_LAYER 6
74
+ TRANSFORMER_HEAD 8
75
+ TRANSFORMER_POS_DISCOUNT 80
76
+
77
+ PRE_TOKENIZER TransfoXLTokenizer
78
+ PRE_TOKENIZER_PATH ../../../third_party/HMNet/ExampleInitModel/transfo-xl-wt103
79
+ PYLEARN_MODEL ../../../third_party/HMNet/ExampleInitModel/AMI-finetuned
80
+ # e.g. PYLEARN_MODEL conf_hmnet_AMI_conf~/run_1/11600
81
+
82
+ ##################
83
+ # Tokenizer settings
84
+ ##################
85
+
86
+ EXTRA_IDS 1000
87
+
88
+ ##################
89
+ # Decoding settings
90
+ ##################
91
+
92
+ BEAM_WIDTH 6
93
+ EVAL_TOKENIZED
94
+ EVAL_LOWERCASE
95
+ # MAX_GEN_LENGTH 300
96
+ MAX_GEN_LENGTH 60
97
+ MIN_GEN_LENGTH 10
98
+ NO_REPEAT_NGRAM_SIZE 3
model/dialogue/hmnet_model.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.base_model import SummModel
2
+ import argparse
3
+ import os
4
+ import torch
5
+ import gzip
6
+ import json
7
+ from model.third_party.HMNet.Models.Trainers.HMNetTrainer import HMNetTrainer
8
+ from model.third_party.HMNet.Utils.Arguments import Arguments
9
+
10
+ import spacy
11
+
12
+ nlp = spacy.load("en_core_web_sm", disable=["parser"])
13
+ # tagger = nlp.get_pipe('tagger')
14
+ # ner = nlp.get_pipe('ner')
15
+ # POS = {w: i for i, w in enumerate([''] + list(tagger.labels))}
16
+ # ENT = {w: i for i, w in enumerate([''] + list(ner.move_names))}
17
+ # These two dicts are adapted from SpaCy 2.3.1, since HMNet's embedding for POS and ENT is fixed
18
+ POS = {
19
+ "": 0,
20
+ "$": 1,
21
+ "''": 2,
22
+ ",": 3,
23
+ "-LRB-": 4,
24
+ "-RRB-": 5,
25
+ ".": 6,
26
+ ":": 7,
27
+ "ADD": 8,
28
+ "AFX": 9,
29
+ "CC": 10,
30
+ "CD": 11,
31
+ "DT": 12,
32
+ "EX": 13,
33
+ "FW": 14,
34
+ "HYPH": 15,
35
+ "IN": 16,
36
+ "JJ": 17,
37
+ "JJR": 18,
38
+ "JJS": 19,
39
+ "LS": 20,
40
+ "MD": 21,
41
+ "NFP": 22,
42
+ "NN": 23,
43
+ "NNP": 24,
44
+ "NNPS": 25,
45
+ "NNS": 26,
46
+ "PDT": 27,
47
+ "POS": 28,
48
+ "PRP": 29,
49
+ "PRP$": 30,
50
+ "RB": 31,
51
+ "RBR": 32,
52
+ "RBS": 33,
53
+ "RP": 34,
54
+ "SYM": 35,
55
+ "TO": 36,
56
+ "UH": 37,
57
+ "VB": 38,
58
+ "VBD": 39,
59
+ "VBG": 40,
60
+ "VBN": 41,
61
+ "VBP": 42,
62
+ "VBZ": 43,
63
+ "WDT": 44,
64
+ "WP": 45,
65
+ "WP$": 46,
66
+ "WRB": 47,
67
+ "XX": 48,
68
+ "_SP": 49,
69
+ "``": 50,
70
+ }
71
+ ENT = {
72
+ "": 0,
73
+ "B-ORG": 1,
74
+ "B-DATE": 2,
75
+ "B-PERSON": 3,
76
+ "B-GPE": 4,
77
+ "B-MONEY": 5,
78
+ "B-CARDINAL": 6,
79
+ "B-NORP": 7,
80
+ "B-PERCENT": 8,
81
+ "B-WORK_OF_ART": 9,
82
+ "B-LOC": 10,
83
+ "B-TIME": 11,
84
+ "B-QUANTITY": 12,
85
+ "B-FAC": 13,
86
+ "B-EVENT": 14,
87
+ "B-ORDINAL": 15,
88
+ "B-PRODUCT": 16,
89
+ "B-LAW": 17,
90
+ "B-LANGUAGE": 18,
91
+ "I-ORG": 19,
92
+ "I-DATE": 20,
93
+ "I-PERSON": 21,
94
+ "I-GPE": 22,
95
+ "I-MONEY": 23,
96
+ "I-CARDINAL": 24,
97
+ "I-NORP": 25,
98
+ "I-PERCENT": 26,
99
+ "I-WORK_OF_ART": 27,
100
+ "I-LOC": 28,
101
+ "I-TIME": 29,
102
+ "I-QUANTITY": 30,
103
+ "I-FAC": 31,
104
+ "I-EVENT": 32,
105
+ "I-ORDINAL": 33,
106
+ "I-PRODUCT": 34,
107
+ "I-LAW": 35,
108
+ "I-LANGUAGE": 36,
109
+ "L-ORG": 37,
110
+ "L-DATE": 38,
111
+ "L-PERSON": 39,
112
+ "L-GPE": 40,
113
+ "L-MONEY": 41,
114
+ "L-CARDINAL": 42,
115
+ "L-NORP": 43,
116
+ "L-PERCENT": 44,
117
+ "L-WORK_OF_ART": 45,
118
+ "L-LOC": 46,
119
+ "L-TIME": 47,
120
+ "L-QUANTITY": 48,
121
+ "L-FAC": 49,
122
+ "L-EVENT": 50,
123
+ "L-ORDINAL": 51,
124
+ "L-PRODUCT": 52,
125
+ "L-LAW": 53,
126
+ "L-LANGUAGE": 54,
127
+ "U-ORG": 55,
128
+ "U-DATE": 56,
129
+ "U-PERSON": 57,
130
+ "U-GPE": 58,
131
+ "U-MONEY": 59,
132
+ "U-CARDINAL": 60,
133
+ "U-NORP": 61,
134
+ "U-PERCENT": 62,
135
+ "U-WORK_OF_ART": 63,
136
+ "U-LOC": 64,
137
+ "U-TIME": 65,
138
+ "U-QUANTITY": 66,
139
+ "U-FAC": 67,
140
+ "U-EVENT": 68,
141
+ "U-ORDINAL": 69,
142
+ "U-PRODUCT": 70,
143
+ "U-LAW": 71,
144
+ "U-LANGUAGE": 72,
145
+ "O": 73,
146
+ }
147
+
148
+
149
+ class HMNetModel(SummModel):
150
+ # static variables
151
+ model_name = "HMNET"
152
+ is_extractive = False
153
+ is_neural = True
154
+ is_dialogue_based = True
155
+
156
def __init__(
    self,
    min_gen_length: int = 10,
    max_gen_length: int = 300,
    beam_width: int = 6,
    **kwargs,
):
    """
    Create an HMNet-backed dialogue summarization model.

    With these defaults inference takes roughly 10s/sample on one GPU;
    smaller settings (e.g. min_gen_length=10, max_gen_length=100,
    beam_width=2) bring that down to about 2s/sample.

    Args:
        min_gen_length (int): minimum generation length of the decoder
        max_gen_length (int): maximum generation length of the decoder
        beam_width (int): beam width used during beam-search decoding
        kwargs: any other option accepted by
            model/dialogue/hmnet/config/dialogue.conf (names are
            case-insensitive; invalid names are reported and ignored by
            _parse_args). Overriding them is discouraged, since values
            other than the defaults are largely untested.

    Return an instance of HMNet model for dialogue summarization.
    """
    super().__init__()
    self.root_path = self._get_root()

    # Surface the three most influential decoding knobs as explicit
    # parameters, then fold them back into the generic config overrides.
    overrides = dict(kwargs)
    overrides["MIN_GEN_LENGTH"] = min_gen_length
    overrides["MAX_GEN_LENGTH"] = max_gen_length
    overrides["BEAM_WIDTH"] = beam_width

    self.opt = self._parse_args(overrides)
    self.model = HMNetTrainer(self.opt)
199
+
200
def _get_root(self):
    """
    Walk up from the current working directory until a directory containing
    a "model" folder is found, then return its "model/dialogue" subpath.

    Raises:
        FileNotFoundError: if the filesystem root is reached without finding
            a "model" directory. (The previous version looped forever here,
            since os.path.dirname("/") == "/".)
    """
    root_path = os.getcwd()
    while "model" not in os.listdir(root_path):
        parent = os.path.dirname(root_path)
        if parent == root_path:  # reached the filesystem root
            raise FileNotFoundError(
                "Could not locate a 'model' directory above the current working directory."
            )
        root_path = parent
    root_path = os.path.join(root_path, "model/dialogue")
    return root_path
206
+
207
def _parse_args(self, kwargs):
    """
    Build the HMNet option dictionary from (in increasing precedence):
    the conf file, --config_overrides, command-line flags, and `kwargs`.

    :param kwargs: option overrides; keys are matched case-insensitively
        against the upper-case keys already present in the conf file.
    :return: the merged option dictionary consumed by HMNetTrainer.
    """
    # NOTE(review): parser.parse_args() reads sys.argv of the *host*
    # process, so unrelated CLI flags can leak into (or break) this
    # library call — confirm whether parse_args([]) was intended.
    parser = argparse.ArgumentParser(
        description="HMNet: Pretrain or fine-tune models for HMNet model."
    )
    parser.add_argument(
        "--command", default="evaluate", help="Command: train/evaluate"
    )
    parser.add_argument(
        "--conf_file",
        default=os.path.join(self.root_path, "hmnet/config/dialogue.conf"),
        help="Path to the BigLearn conf file.",
    )
    parser.add_argument(
        "--PYLEARN_MODEL", help="Overrides this option from the conf file."
    )
    parser.add_argument(
        "--master_port", help="Overrides this option default", default=None
    )
    parser.add_argument("--cluster", help="local, philly or aml", default="local")
    parser.add_argument(
        "--dist_init_path", help="Distributed init path for AML", default="./tmp"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Disable cuda.")
    parser.add_argument(
        "--config_overrides",
        help="Override parameters on config, VAR=val;VAR=val;...",
    )

    cmdline_args = parser.parse_args()
    command = cmdline_args.command
    conf_file = cmdline_args.conf_file
    # Base options come from the conf file.
    conf_args = Arguments(conf_file)
    opt = conf_args.readArguments()

    # Apply "VAR=val;VAR=val" style overrides on top of the conf file.
    if cmdline_args.config_overrides:
        for config_override in cmdline_args.config_overrides.split(";"):
            config_override = config_override.strip()
            if config_override:
                var_val = config_override.split("=")
                assert (
                    len(var_val) == 2
                ), f"Config override '{var_val}' does not have the form 'VAR=val'"
                conf_args.add_opt(opt, var_val[0], var_val[1], force_override=True)

    opt["cuda"] = torch.cuda.is_available() and not cmdline_args.no_cuda
    opt["confFile"] = conf_file
    if "datadir" not in opt:
        opt["datadir"] = os.path.dirname(
            conf_file
        )  # conf_file specifies where the data folder is
    opt["basename"] = os.path.basename(
        conf_file
    )  # conf_file specifies where the name of save folder is
    opt["command"] = command

    # combine cmdline_args into opt dictionary
    for key, val in cmdline_args.__dict__.items():
        # if val is not None and key not in ['command', 'conf_file']:
        if val is not None:
            opt[key] = val

    # combine kwargs into opt dictionary (we allow lower case); keys that do
    # not match an existing all-upper-case option are warned about and skipped.
    for key, val in kwargs.items():
        valid_keys = [x for x in opt.keys() if x.upper() == x]
        if key.upper() not in valid_keys:
            print("WARNING: {} is not a valid key in HMNet.".format(key))
            print("The valid keys are:", valid_keys)
            continue
        if val is not None:
            opt[key.upper()] = val

    return opt
291
+
292
def summarize(self, corpus, queries=None):
    """
    Run HMNet inference over `corpus` (a list of dialogue samples) and
    return one generated summary per sample. `queries` is unused.
    """
    print(f"HMNet model: processing document of {len(corpus)} samples")

    # HMNet consumes data from disk: rewrite the input corpus into the
    # preprocessed AMI test-set layout, then run the evaluation loop over
    # it. Only the test-set path is used for inference.
    data_folder = os.path.join(
        os.path.dirname(self.opt["datadir"]),
        "ExampleRawData/meeting_summarization/AMI_proprec/test",
    )
    self._create_datafolder(data_folder)
    self._preprocess(corpus, data_folder)

    return self._evaluate()
308
+
309
def _evaluate(self):
    """
    Set up the underlying HMNet trainer and decode the "test" split,
    returning the generated summaries (one per sample).
    """
    # Only rank 0 logs — presumably to avoid duplicate output across
    # distributed workers (confirm against HMNetTrainer).
    if self.opt["rank"] == 0:
        self.model.log("-----------------------------------------------")
        self.model.log("Evaluating model ... ")

    self.model.set_up_model()

    eval_dataset = "test"
    batch_generator_eval = self.model.get_batch_generator(eval_dataset)
    predictions = self._eval_batches(
        self.model.module, batch_generator_eval, self.model.saveFolder, eval_dataset
    )

    return predictions
323
+
324
def _eval_batches(self, module, dev_batches, save_folder, label=""):
    """Beam-search decode `dev_batches` with `module` and return the top-1
    prediction string for every sample.

    Args:
        module: the underlying HMNet network (e.g. unwrapped DataParallel).
        dev_batches: iterable of batches from the batch generator.
        save_folder: folder path, used only for the progress message here.
        label: dataset label; currently unused inside this method.
    """
    max_sent_len = int(self.opt["MAX_GEN_LENGTH"])

    print("Decoding current model ... \nSaving folder is {}".format(save_folder))
    print("Each sample will cost about 10 second.")
    import time

    start_time = time.time()
    predictions = []  # prediction of tokens from model
    # HMNet may carry a single tokenizer or a [encoder, decoder] pair;
    # pick the decoder-side tokenizer for converting tokens to text.
    if not isinstance(module.tokenizer, list):
        decoder_tokenizer = module.tokenizer
    elif len(module.tokenizer) == 1:
        decoder_tokenizer = module.tokenizer[0]
    elif len(module.tokenizer) == 2:
        decoder_tokenizer = module.tokenizer[1]
    else:
        assert False, "len(module.tokenizer) > 2"

    with torch.no_grad():
        for j, dev_batch in enumerate(dev_batches):
            # Move every tensor in the batch to the configured device.
            for b in dev_batch:
                if torch.is_tensor(dev_batch[b]):
                    dev_batch[b] = dev_batch[b].to(self.opt["device"])

            beam_search_res = module(
                dev_batch, beam_search=True, max_sent_len=max_sent_len
            )
            # Keep the best beam (t[0]) for every sample in the batch;
            # fall back to an empty token list when a beam is empty.
            pred = [
                [t[0] for t in x] if len(x) > 0 else [[]] for x in beam_search_res
            ]
            predictions.extend(
                [
                    [
                        self._convert_tokens_to_string(decoder_tokenizer, tt)
                        for tt in t
                    ]
                    for t in pred
                ]
            )

            if (
                "DEBUG" in self.opt and j >= 10
            ) or j >= self.model.task.evaluator.eval_batches_num:
                # In debug mode decode only the first 10 batches; otherwise
                # decode the first `eval_batches_num` batches.
                break

    top1_predictions = [x[0] for x in predictions]

    print("Total time for inference:", time.time() - start_time)
    return top1_predictions
374
+
375
def _convert_tokens_to_string(self, tokenizer, tokens):
    """Render a list of tokens as a summary string, honouring the
    EVAL_TOKENIZED / EVAL_LOWERCASE options stored in `self.opt`."""
    tokenized_eval = "EVAL_TOKENIZED" in self.opt
    if tokenized_eval:
        # Special tokens must be stripped manually here; the decode()
        # branch below strips them itself.
        tokens = [tok for tok in tokens if tok not in tokenizer.all_special_tokens]
    if "EVAL_LOWERCASE" in self.opt:
        tokens = [tok.lower() for tok in tokens]
    if tokenized_eval:
        return " ".join(tokens)
    return tokenizer.decode(
        tokenizer.convert_tokens_to_ids(tokens), skip_special_tokens=True
    )
386
+
387
def _preprocess(self, corpus, test_path):
    """Convert `corpus` (a list of dialogues, each a list of
    "speaker: utterance" strings) into HMNet's gzipped JSONL shard format
    under `test_path`.

    Each utterance is tokenized with the module-level spaCy `nlp` pipeline
    and annotated with POS / entity ids from the ENT and POS vocabularies.
    A dummy summary is attached because HMNet filters out samples that
    have no summary.
    """
    # NOTE(review): `samples` is accumulated but never used after the loop.
    samples = []
    for i, sample in enumerate(corpus):
        new_sample = {"id": i, "meeting": [], "summary": []}
        if isinstance(sample, str):
            raise RuntimeError(
                "Error: the input of HMNet should be dialogues, rather than documents."
            )

        # add all the turns one by one
        for turn in sample:
            # "speaker: utterance" -> [speaker, utterance, ...]
            # NOTE(review): splits on every ':'; text after a second colon
            # in the utterance is dropped — confirm this is intended.
            turn = [x.strip() for x in turn.split(":")]
            if len(turn) < 2:
                continue
            tokenized_turn = nlp(turn[1])
            # In case we can't find proper entity in move_names
            ent_id = []
            pos_id = []
            for token in tokenized_turn:
                # IOB-prefixed entity label, or "O" for non-entities;
                # unknown labels map to the vocabulary's default entry.
                ent = (
                    token.ent_iob_ + "-" + token.ent_type_
                    if token.ent_iob_ != "O"
                    else "O"
                )
                ent_id.append(ENT[ent] if ent in ENT else ENT[""])

                pos = token.tag_
                pos_id.append(POS[pos] if pos in POS else POS[""])

            new_sample["meeting"].append(
                {
                    "speaker": turn[0],
                    "role": "",
                    "utt": {
                        "word": [str(token) for token in tokenized_turn],
                        "pos_id": pos_id,
                        "ent_id": ent_id,
                    },
                }
            )
        new_sample["summary"].append(
            "This is a dummy summary. HMNet will filter out the sample w/o summary!"
        )
        samples.append(new_sample)
        # save to the gzip
        file_path = os.path.join(test_path, "split_{}.jsonl.gz".format(i))
        with gzip.open(file_path, "wt", encoding="utf-8") as file:
            file.write(json.dumps(new_sample))
435
+
436
def _clean_datafolder(self, data_folder):
    """Delete the gzipped data shards (written by `_preprocess`) from
    `data_folder`, leaving every other entry untouched."""
    for name in os.listdir(data_folder):
        path = os.path.join(data_folder, name)
        # BUG FIX: the previous substring test (".gz" in name) also matched
        # unrelated entries such as "foo.gz.bak"; match real *.gz files only.
        if name.endswith(".gz") and os.path.isfile(path):
            os.remove(path)
441
+
442
def _create_datafolder(self, data_folder):
    """Ensure `data_folder` exists and holds no stale gzip shards, then
    write the two metadata files HMNet's loader expects: a dataset
    manifest (test_ami.json) and an empty speaker-role dictionary.
    """
    if os.path.exists(data_folder):
        self._clean_datafolder(data_folder)
    else:
        os.makedirs(data_folder)
    # Manifest pointing HMNet's loader at the (relative) test folder.
    with open(
        os.path.join(os.path.dirname(data_folder), "test_ami.json"),
        "w",
        encoding="utf-8",
    ) as file:
        json.dump(
            [
                {
                    "source": {
                        "dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/"
                    },
                    "task": "meeting",
                    "name": "ami",
                }
            ],
            file,
        )

    # Empty role dictionary: no speaker-role information is available here.
    with open(
        os.path.join(
            os.path.dirname(os.path.dirname(data_folder)), "role_dict_ext.json"
        ),
        "w",
    ) as file:
        json.dump({}, file)
472
+
473
@classmethod
def show_capability(cls) -> None:
    """Print a human-readable description of the model's capabilities."""
    details = (
        "A HMNet model finetuned on CNN-DM dataset for summarization.\n\n"
        "Strengths:\n - High performance on dialogue summarization task.\n\n"
        "Weaknesses:\n - Not suitable for datasets other than dialogues.\n\n"
        "Initialization arguments:\n "
        " - `corpus`: Unlabelled corpus of documents.\n"
    )
    print(f"{cls.generate_basic_description()} \n {'#' * 20} \n {details}")
model/multi_doc/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
1
+ from .multi_doc_joint_model import MultiDocJointModel
2
+ from .multi_doc_separate_model import MultiDocSeparateModel
model/multi_doc/base_multi_doc_model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.base_model import SummModel
2
+
3
+
4
class MultiDocSummModel(SummModel):
    """Base class for multi-document summarization models.

    Instances summarize corpora of type `List[List[str]]`: each element is
    one instance, itself a list of documents.
    """

    is_multi_document = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
    ):
        super(MultiDocSummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
        )

    @classmethod
    def assert_summ_input_type(cls, corpus, query):
        """Raise TypeError unless `corpus` is `List[List[str]]` and, if
        given, `query` is `List[str]`."""
        if not all(
            isinstance(ins, list) and all(isinstance(doc, str) for doc in ins)
            for ins in corpus
        ):
            raise TypeError(
                "Multi-document summarization models summarize instances of multiple documents (`List[List[str]]`)."
            )

        if query is not None:
            # BUG FIX: the message previously said "single-document",
            # copy-pasted from the single-doc base class.
            if not isinstance(query, list) or not all(
                isinstance(q, str) for q in query
            ):
                raise TypeError(
                    "Query-based multi-document summarization requires query of `List[str]`."
                )
model/multi_doc/multi_doc_joint_model.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_multi_doc_model import MultiDocSummModel
2
+ from model.base_model import SummModel
3
+ from model.single_doc import TextRankModel
4
+ from typing import Union, List
5
+
6
+
7
class MultiDocJointModel(MultiDocSummModel):
    """Multi-document summarizer that concatenates all documents of an
    instance and summarizes the joint text with a single-document backend."""

    model_name = "Multi-document joint"
    is_multi_document = True

    def __init__(self, model_backend: SummModel = TextRankModel, **kwargs):
        super(MultiDocJointModel, self).__init__()
        # Single-document backend applied to the concatenated text.
        self.model = model_backend(**kwargs)

    def summarize(
        self,
        corpus: Union[List[str], List[List[str]]],
        query: Union[List[str], List[List[str]]] = None,
    ) -> List[str]:
        """Return one summary per instance in `corpus`."""
        self.assert_summ_input_type(corpus, None)
        # One space-joined "document" per instance.
        joint_corpus = [" ".join(instance) for instance in corpus]
        return self.model.summarize(joint_corpus)

    @classmethod
    def generate_basic_description(cls) -> str:
        return (
            "MultiDocJointModel performs multi-document summarization by"
            " first concatenating all documents,"
            " and then performing single-document summarization on the concatenation."
        )

    @classmethod
    def show_capability(cls):
        details = (
            "A multi-document summarization model."
            " Allows for custom model backend selection at initialization."
            " Concatenates each document in corpus and returns single-document summarization of joint corpus.\n"
            "Strengths: \n - Allows for control of backend model.\n"
            "Weaknesses: \n - Assumes all documents are equally weighted.\n"
            " - May fail to extract information from certain documents.\n"
        )
        print(f"{cls.generate_basic_description()}\n{'#' * 20}\n{details}")
model/multi_doc/multi_doc_separate_model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_multi_doc_model import MultiDocSummModel
2
+ from model.base_model import SummModel
3
+ from model.single_doc import TextRankModel
4
+ from typing import Union, List
5
+
6
+
7
class MultiDocSeparateModel(MultiDocSummModel):
    """Multi-document summarizer that summarizes every document separately
    with a single-document backend, then concatenates the per-document
    summaries for each instance."""

    model_name = "Multi-document separate"
    is_multi_document = True

    def __init__(self, model_backend: SummModel = TextRankModel, **kwargs):
        super(MultiDocSeparateModel, self).__init__()
        # Single-document backend applied to every document independently.
        self.model = model_backend(**kwargs)

    def summarize(
        self,
        corpus: Union[List[str], List[List[str]]],
        query: Union[List[str], List[List[str]]] = None,
    ) -> List[str]:
        """Return one summary per instance in `corpus`."""
        self.assert_summ_input_type(corpus, None)
        summaries = []
        for instance in corpus:
            per_doc_summaries = self.model.summarize(instance)
            summaries.append(" ".join(per_doc_summaries))
        return summaries

    @classmethod
    def generate_basic_description(cls) -> str:
        return (
            "MultiDocSeparateModel performs multi-document summarization by"
            " first performing single-document summarization on each document,"
            " and then concatenating the results."
        )

    @classmethod
    def show_capability(cls):
        details = (
            "A multi-document summarization model."
            " Allows for custom model backend selection at initialization."
            " Performs single-document summarization on each document in corpus and returns concatenated result.\n"
            "Strengths: \n - Allows for control of backend model.\n"
            "Weaknesses: \n - Assumes all documents are equally weighted.\n - May produce redundant information for similar documents.\n"
        )
        print(f"{cls.generate_basic_description()}\n{'#' * 20}\n{details}")
model/query_based/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
1
+ from .bm25_model import BM25SummModel
2
+ from .tf_idf_model import TFIDFSummModel
model/query_based/base_query_based_model.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.base_model import SummModel
2
+ from model.single_doc import TextRankModel
3
+ from typing import List, Union
4
+
5
+ from nltk import sent_tokenize, word_tokenize
6
+ from nltk.corpus import stopwords
7
+ from nltk.stem import PorterStemmer
8
+
9
+
10
class QueryBasedSummModel(SummModel):
    """Base class for query-based summarization pipelines.

    A retriever (`_retrieve`, implemented by subclasses) first selects the
    sentences or dialogue turns most relevant to the query; a
    single-document backend model then summarizes the retrieved text.
    """

    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs,
    ):
        super(QueryBasedSummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
        )
        self.model = model_backend(**kwargs)  # summarization backend
        self.retrieval_ratio = retrieval_ratio  # fraction of sentences kept
        self.preprocess = preprocess

    def _retrieve(self, instance: List[str], query: List[str], n_best) -> List[str]:
        """Return the `n_best` sentences of `instance` most relevant to
        `query`. Must be implemented by subclasses."""
        raise NotImplementedError()

    def summarize(
        self,
        corpus: Union[List[str], List[List[str]]],
        queries: List[str] = None,
    ) -> List[str]:
        """Retrieve query-relevant content from each instance, then
        summarize the retrieved text with the backend model."""
        self.assert_summ_input_type(corpus, queries)

        retrieval_output = []  # List[str]
        for instance, query in zip(corpus, queries):
            if isinstance(instance, str):
                is_dialogue = False
                instance = sent_tokenize(instance)
            else:
                is_dialogue = True
            # BUG FIX: the query must be wrapped into a one-element list in
            # BOTH branches; previously the document (non-dialogue) branch
            # passed the raw string, so `Preprocessor.preprocess` and
            # `_retrieve` iterated over its characters.
            query = [query]

            # instance & query now are List[str] for sure
            if self.preprocess:
                preprocessor = Preprocessor()
                instance = preprocessor.preprocess(instance)
                query = preprocessor.preprocess(query)

            # Keep at least one sentence regardless of the ratio.
            n_best = max(int(len(instance) * self.retrieval_ratio), 1)
            top_n_sent = self._retrieve(instance, query, n_best)

            if not is_dialogue:
                top_n_sent = " ".join(top_n_sent)  # str
            retrieval_output.append(top_n_sent)

        summaries = self.model.summarize(
            retrieval_output
        )  # List[str] or List[List[str]]
        return summaries

    def generate_specific_description(self):
        """Describe this retriever+summarizer pipeline, combining the
        capability flags of both stages."""
        is_neural = self.model.is_neural & self.is_neural
        is_extractive = self.model.is_extractive | self.is_extractive
        model_name = "Pipeline with retriever: {}, summarizer: {}".format(
            self.model_name, self.model.model_name
        )

        extractive_abstractive = "extractive" if is_extractive else "abstractive"
        neural = "neural" if is_neural else "non-neural"

        basic_description = (
            f"{model_name} is a "
            f"{'query-based' if self.is_query_based else ''} "
            f"{extractive_abstractive}, {neural} model for summarization."
        )

        return basic_description

    @classmethod
    def assert_summ_input_type(cls, corpus, query):
        """Raise TypeError unless `query` is a `List[str]` accompanying the
        corpus."""
        if query is None:
            raise TypeError(
                "Query-based summarization models summarize instances of query-text pairs, however, query is missing."
            )

        if not isinstance(query, list):
            raise TypeError(
                "Query-based single-document summarization requires query of `List[str]`."
            )
        if not all([isinstance(q, str) for q in query]):
            raise TypeError(
                "Query-based single-document summarization requires query of `List[str]`."
            )

    @classmethod
    def generate_basic_description(cls) -> str:
        basic_description = (
            "QueryBasedSummModel performs query-based summarization. Given a query-text pair,"
            "the model will first extract the most relevant sentences in articles or turns in "
            "dialogues, then use the single document summarization model to generate the summary"
        )
        return basic_description

    @classmethod
    def show_capability(cls):
        basic_description = cls.generate_basic_description()
        more_details = (
            "A query-based summarization model."
            " Allows for custom model backend selection at initialization."
            " Retrieve relevant turns and then summarize the retrieved turns\n"
            "Strengths: \n - Allows for control of backend model.\n"
            "Weaknesses: \n - Heavily depends on the performance of both retriever and summarizer.\n"
        )
        print(f"{basic_description}\n{'#' * 20}\n{more_details}")
124
+
125
+
126
class Preprocessor:
    """Light-weight text normalizer for retrieval: lower-casing, stop-word
    removal and optional Porter stemming over a list of sentences."""

    def __init__(self, remove_stopwords=True, lower_case=True, stem=False):
        self.sw = stopwords.words("english")
        self.stemmer = PorterStemmer()
        self.remove_stopwords = remove_stopwords
        self.lower_case = lower_case
        self.stem = stem

    def preprocess(self, corpus: List[str]) -> List[str]:
        """Normalize every sentence in `corpus` and return the normalized
        sentences (whitespace-joined tokens)."""
        if self.lower_case:
            corpus = [sentence.lower() for sentence in corpus]
        normalized = []
        for sentence in corpus:
            words = word_tokenize(sentence)
            if self.remove_stopwords:
                words = [w for w in words if w not in self.sw]
            if self.stem:
                words = [self.stemmer.stem(w) for w in words]
            normalized.append(" ".join(words))
        return normalized
model/query_based/bm25_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_query_based_model import QueryBasedSummModel
2
+ from model.base_model import SummModel
3
+ from model.single_doc import TextRankModel
4
+ from typing import List
5
+
6
+ from gensim.summarization.bm25 import BM25
7
+ from nltk import word_tokenize
8
+
9
+
10
class BM25SummModel(QueryBasedSummModel):
    """Query-based retriever that ranks sentences/turns with Okapi BM25
    (gensim implementation) before summarizing with the backend model."""

    # static variables
    model_name = "BM25"
    is_extractive = True  # only represents the retrieval part
    is_neural = False  # only represents the retrieval part
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        # All arguments are forwarded unchanged to QueryBasedSummModel.
        super(BM25SummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )

    def _retrieve(self, instance: List[str], query: List[str], n_best):
        """Return the `n_best` sentences of `instance` best matching
        `query`, in original sentence order.

        NOTE(review): `gensim.summarization` was removed in gensim 4.x;
        this import only works with gensim < 4.
        """
        bm25 = BM25(word_tokenize(s) for s in instance)
        # NOTE(review): BM25.get_scores expects a tokenized query (list of
        # words); `query` here is a list of sentence strings — confirm.
        scores = bm25.get_scores(query)
        best_sent_ind = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )[:n_best]
        # Re-sort indices so the retrieved sentences keep document order.
        top_n_sent = [instance[ind] for ind in sorted(best_sent_ind)]
        return top_n_sent
model/query_based/tf_idf_model.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base_query_based_model import QueryBasedSummModel
2
+ from model.base_model import SummModel
3
+ from model.single_doc import TextRankModel
4
+ from typing import List
5
+
6
+ from sklearn.feature_extraction.text import TfidfVectorizer
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+
9
+
10
class TFIDFSummModel(QueryBasedSummModel):
    """Query-based retriever scoring sentences by cosine similarity of
    TF-IDF vectors before summarizing with the backend model."""

    # static variables
    model_name = "TF-IDF"
    is_extractive = True
    is_neural = False
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        super(TFIDFSummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )
        self.vectorizer = TfidfVectorizer()

    def _retrieve(self, instance: List[str], query: List[str], n_best):
        """Return the `n_best` sentences most similar to the query.

        NOTE(review): unlike BM25SummModel, results are returned in
        descending-similarity order rather than original sentence order.
        """
        sentence_vectors = self.vectorizer.fit_transform(instance)
        query_vector = self.vectorizer.transform(query)

        scores = cosine_similarity(query_vector, sentence_vectors).squeeze()
        ranked_indices = scores.argsort()[::-1][0:n_best]
        return [instance[idx] for idx in ranked_indices]  # List[str]
model/single_doc/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ from .bart_model import BartModel
2
+ from .pegasus_model import PegasusModel
3
+ from .lexrank_model import LexRankModel
4
+ from .longformer_model import LongformerModel
5
+ from .textrank_model import TextRankModel
model/single_doc/bart_model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BartForConditionalGeneration, BartTokenizer
2
+ from .base_single_doc_model import SingleDocSummModel
3
+
4
+
5
class BartModel(SingleDocSummModel):
    """Abstractive single-document summarizer backed by the pretrained
    facebook/bart-large-cnn checkpoint."""

    # static variables
    model_name = "BART"
    is_extractive = False
    # BUG FIX: BART is a neural (transformer) model; this flag was False.
    is_neural = True

    def __init__(self, device="cpu"):
        """Load tokenizer and model.

        Args:
            device: torch device string ("cpu", "cuda", ...) on which the
                model runs and to which inputs are moved.
        """
        super(BartModel, self).__init__()

        self.device = device
        model_name = "facebook/bart-large-cnn"
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        # BUG FIX: move the model to the requested device; previously only
        # the inputs were moved, which fails for any device != "cpu".
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(
            device
        )

    def summarize(self, corpus, queries=None):
        """Generate one abstractive summary per document in `corpus`."""
        self.assert_summ_input_type(corpus, queries)

        batch = self.tokenizer(
            corpus, truncation=True, padding="longest", return_tensors="pt"
        ).to(self.device)
        encoded_summaries = self.model.generate(**batch)
        summaries = self.tokenizer.batch_decode(
            encoded_summaries, skip_special_tokens=True
        )

        return summaries

    @classmethod
    def show_capability(cls) -> None:
        # TODO zhangir: add the show capability function for BART
        print(cls.generate_basic_description())
model/single_doc/base_single_doc_model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.base_model import SummModel
2
+
3
+
4
class SingleDocSummModel(SummModel):
    """Base class for single-document summarization models; validates that
    corpora are `List[str]` and optional queries are `List[str]`."""

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
    ):
        super(SingleDocSummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
        )

    @classmethod
    def assert_summ_input_type(cls, corpus, query):
        """Raise TypeError unless `corpus` is `List[str]` and, if given,
        `query` is `List[str]`."""
        corpus_message = "Single-document summarization requires corpus of `List[str]`."
        if not isinstance(corpus, list):
            raise TypeError(corpus_message)
        if not all(isinstance(ins, str) for ins in corpus):
            raise TypeError(corpus_message)

        if query is not None:
            query_message = (
                "Query-based single-document summarization requires query of `List[str]`."
            )
            if not isinstance(query, list):
                raise TypeError(query_message)
            if not all(isinstance(q, str) for q in query):
                raise TypeError(query_message)