kenichiro committed on
Commit
46a030d
1 Parent(s): b615e10
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright [yyyy] [name of copyright owner]
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Clinical Segnemt
- emoji: 🚑
+ emoji: 🌖
  colorFrom: purple
  colorTo: yellow
  sdk: streamlit
__pycache__/chat.cpython-38.pyc DELETED
Binary file (1.46 kB)

__pycache__/functionforDownloadButtons.cpython-36.pyc ADDED
Binary file (4.54 kB)

__pycache__/functionforDownloadButtons.cpython-38.pyc ADDED
Binary file (4.59 kB)

__pycache__/model.cpython-36.pyc ADDED
Binary file (7.24 kB)

__pycache__/model.cpython-38.pyc ADDED
Binary file (7.26 kB)

__pycache__/model2.cpython-36.pyc ADDED
Binary file (7.04 kB)

__pycache__/run_segbot.cpython-36.pyc ADDED
Binary file (1.9 kB)

__pycache__/run_segbot.cpython-38.pyc ADDED
Binary file (1.9 kB)

__pycache__/solver.cpython-36.pyc ADDED
Binary file (4.81 kB)

__pycache__/solver.cpython-38.pyc ADDED
Binary file (4.83 kB)

__pycache__/solver2.cpython-36.pyc ADDED
Binary file (4.43 kB)
app.py CHANGED
@@ -1,19 +1,122 @@
- from flask import Flask, render_template, request, jsonify
-
- from chat import get_response
-
- app = Flask(__name__)
-
- @app.get("/")
- def index_get():
-     return render_template("base.html")
-
- @app.post("/predict")
- def predict():
-     text = request.get_json().get("message")
-     response = get_response(text)
-     message = {"answer": response}
-     return jsonify(message)
-
- if __name__=="__main__":
-     app.run(debug=True)
+ import streamlit as st
+ import numpy as np
+ from pandas import DataFrame
+ import run_segbot
+ from functionforDownloadButtons import download_button
+ import os
+ import json
+
+ st.set_page_config(
+     page_title="Clinical segment generater",
+     page_icon="🚑",
+     layout="wide"
+ )
+
+
+ def _max_width_():
+     max_width_str = f"max-width: 1400px;"
+     st.markdown(
+         f"""
+         <style>
+         .reportview-container .main .block-container{{
+             {max_width_str}
+         }}
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+
+ #_max_width_()
+
+ #c30 = st.columns([1,])
+
+ #with c30:
+ #    st.image("logo.png", width=400)
+ st.title("🚑 Clinical segment generater")
+ st.header("")
+
+
+ with st.expander("ℹ️ - About this app", expanded=True):
+
+     st.write(
+         """
+ - The *Clinical segment generater* app is an implementation of [our paper](https://journals.plos.org/digitalhealth/article?id=10.1371/journal.pdig.0000099).
+ - It automatically splits Japanese sentences into smaller units representing medical meanings.
+         """
+     )
+
+     st.markdown("")
+
+ st.markdown("")
+ st.markdown("## 📌 Paste document")
+ @st.cache(allow_output_mutation=True)
+ def model_load():
+     return run_segbot.setup()
+ model,fm,index = model_load()
+ with st.form(key="my_form"):
+
+     ce, c1, ce, c2, c3 = st.columns([0.07, 1, 0.07, 5, 0.07])
+     with c1:
+         ModelType = st.radio(
+             "Choose the method of sentence split",
+             ["fullstop & linebreak (Default)", "pySBD"],
+             help="""
+             At present, you can choose between 2 methods to split your text into sentences.
+
+             The fullstop & linebreak is naive and robust to noise, but has low accuracy.
+             pySBD is more accurate, but more complex and less robust to noise.
+             """,
+         )
+
+         if ModelType == "fullstop & linebreak (Default)":
+             split_method = "fullstop"
+
+         else:
+             split_method = "pySBD"
+
+     with c2:
+         doc = st.text_area(
+             "Paste your text below",
+             height=510,
+         )
+
+     submit_button = st.form_submit_button(label="👍 Go to split!")
+
+
+ if not submit_button:
+     st.stop()
+
+ keywords = run_segbot.generate(doc, model, fm, index, split_method)
+
+
+ st.markdown("## 🎈 Check & download results")
+
+ st.header("")
+
+
+ cs, c1, c2, c3, cLast = st.columns([2, 1.5, 1.5, 1.5, 2])
+
+ with c1:
+     CSVButton2 = download_button(keywords, "Data.csv", "📥 Download (.csv)")
+ with c2:
+     CSVButton2 = download_button(keywords, "Data.txt", "📥 Download (.txt)")
+ with c3:
+     CSVButton2 = download_button(keywords, "Data.json", "📥 Download (.json)")
+
+ st.header("")
+
+ #df = DataFrame(keywords, columns=["Keyword/Keyphrase", "Relevancy"])
+ df = DataFrame(keywords)
+ df.index += 1
+ df.columns = ['Segment']
+ print(df)
+ # Add styling
+
+ #c1, c2, c3 = st.columns([1, 3, 1])
+
+ #with c2:
+ st.table(df)
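The "fullstop & linebreak" option offered in the form above can be sketched as follows. This is a hypothetical illustration only (the function name and regex are assumptions, not the repo's `run_segbot` implementation): cut at line breaks and after every Japanese full stop, keeping the full stop attached to its sentence.

```python
import re

def split_fullstop_linebreak(text):
    """Naive sentence split: cut at line breaks and after each Japanese
    full stop (。). Hypothetical sketch, not the repo's implementation."""
    parts = []
    for line in text.splitlines():
        # "[^。]*。" grabs a sentence ending in 。; "[^。]+$" grabs a trailing
        # fragment with no full stop (e.g. the last line of a note).
        for sent in re.findall(r"[^。]*。|[^。]+$", line):
            if sent.strip():
                parts.append(sent.strip())
    return parts

print(split_fullstop_linebreak("今日は晴れ。明日は雨。\n検査は正常"))
# → ['今日は晴れ。', '明日は雨。', '検査は正常']
```

The pySBD option would instead delegate to `pysbd.Segmenter(language="ja")`, which handles abbreviations and quoted text but, as the help text notes, is less robust to noisy clinical input.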
credata.py ADDED
@@ -0,0 +1,653 @@
1
+ import gensim
2
+ import MeCab
3
+ import pickle
4
+ from gensim.models.wrappers.fasttext import FastText
5
+ #import fasttext as ft
6
+ import random
7
+ import mojimoji
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+
11
+ def ymyi(lis):
12
+ wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
13
+
14
+ with open('fm_space.pickle', 'rb') as f:
15
+ fm = pickle.load(f)
16
+ #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
17
+ model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
18
+ texts = []
19
+ sent = ""
20
+ sparate = []
21
+ label = []
22
+ ruiseki = 0
23
+ ruiseki2 = 0
24
+ alls = []
25
+ labels, text, num = [], [], []
26
+ for n, line in enumerate(open(lis)):
27
+ line = line.strip("\t").rstrip("\n")
28
+ #print(line)
29
+ if line == "":
30
+ if sent == "":
31
+ continue
32
+ sent = wakati.parse(sent).split(" ")[:-1]
33
+ flag = 0
34
+ for i in sent:
35
+ for j in sparate:
36
+ if ruiseki+len(i) > j and ruiseki < j:
37
+ label.append(1)
38
+ flag = 1
39
+ elif ruiseki+len(i) == j:
40
+ label.append(1)
41
+ flag = 1
42
+ if flag == 0:
43
+ label.append(0)
44
+ flag = 0
45
+ ruiseki += len(i)
46
+ #texts += i + " "
47
+ try:
48
+ texts.append(model[i])
49
+ #texts.append(np.array(fm.vocab[i]))
50
+ #texts += str(fm.vocab[i].index) + " "
51
+ #print(i,str(fm.vocab[i].index))
52
+ except KeyError:
53
+ texts.append(fm["<unk>"])
54
+ label[-1] = 1
55
+ #texts = texts.rstrip() + "\t"
56
+ #texts += " ".join(label) + "\n"
57
+ #alls.append((n,texts,label))
58
+ labels.append(label)
59
+ text.append(texts)
60
+ num.append(n)
61
+ sent = ""
62
+ sparate = []
63
+ texts = []
64
+ label = []
65
+ ruiseki = 0
66
+ ruiseki2 = 0
67
+ continue
68
+ sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
69
+ ruiseki2 += len(line)
70
+ sparate.append(ruiseki2)
71
+ return num,text,labels
72
+
73
+ def nmni(lis):
74
+ #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
75
+ wakati = MeCab.Tagger("-Owakati -b 81920")
76
+
77
+ with open('fm_space.pickle', 'rb') as f:
78
+ fm = pickle.load(f)
79
+ #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
80
+ #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
81
+ texts = []
82
+ sent = ""
83
+ sparate = []
84
+ label = []
85
+ ruiseki = 0
86
+ ruiseki2 = 0
87
+ alls = []
88
+ labels, text, num = [], [], []
89
+ for n, line in enumerate(open(lis)):
90
+ line = line.strip("\t").rstrip("\n")
91
+ #print(line)
92
+ if line == "":
93
+ if sent == "":
94
+ continue
95
+ sent = wakati.parse(sent).split(" ")[:-1]
96
+ flag = 0
97
+ for i in sent:
98
+ for j in sparate:
99
+ if ruiseki+len(i) > j and ruiseki < j:
100
+ label.append(1)
101
+ flag = 1
102
+ elif ruiseki+len(i) == j:
103
+ label.append(1)
104
+ flag = 1
105
+ if flag == 0:
106
+ label.append(0)
107
+ flag = 0
108
+ ruiseki += len(i)
109
+ #texts += i + " "
110
+ try:
111
+ #texts.append(model[i])
112
+ texts.append(fm[i])
113
+ #texts += str(fm.vocab[i].index) + " "
114
+ #print(i,str(fm.vocab[i].index))
115
+ except KeyError:
116
+ texts.append(fm["<unk>"])
117
+ label[-1] = 1
118
+ #texts = texts.rstrip() + "\t"
119
+ #texts += " ".join(label) + "\n"
120
+ #alls.append((n,texts,label))
121
+ labels.append(label)
122
+ text.append(texts)
123
+ num.append(n)
124
+ sent = ""
125
+ sparate = []
126
+ texts = []
127
+ label = []
128
+ ruiseki = 0
129
+ ruiseki2 = 0
130
+ continue
131
+ sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
132
+ ruiseki2 += len(line)
133
+ sparate.append(ruiseki2)
134
+ return num,text,labels
135
+
136
+ def nmni_finetune(lis):
137
+ #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
138
+ wakati = MeCab.Tagger("-Owakati -b 81920")
139
+ #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
140
+ with open('fm.pickle', 'rb') as f:
141
+ fm = pickle.load(f)
142
+ #fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
143
+ #with open('fm.pickle', 'wb') as f:
144
+ # pickle.dump(fm, f)
145
+ #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
146
+ #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
147
+ texts = []
148
+ sent = ""
149
+ sparate = []
150
+ label = []
151
+ ruiseki = 0
152
+ ruiseki2 = 0
153
+ alls = []
154
+ labels, text, num = [], [], []
155
+ for n, line in enumerate(open(lis)):
156
+ line = line.strip("\t").rstrip("\n")
157
+ #print(line)
158
+ if line == "":
159
+ if sent == "":
160
+ continue
161
+ sent = wakati.parse(sent).split(" ")[:-1]
162
+ flag = 0
163
+ for i in sent:
164
+ for j in sparate:
165
+ if ruiseki+len(i) > j and ruiseki < j:
166
+ label.append(1)
167
+ flag = 1
168
+ elif ruiseki+len(i) == j:
169
+ label.append(1)
170
+ flag = 1
171
+ if flag == 0:
172
+ label.append(0)
173
+ flag = 0
174
+ ruiseki += len(i)
175
+ #texts += i + " "
176
+ try:
177
+ #texts.append(model[i])
178
+ #texts.append(fm[i])
179
+ texts.append(fm.vocab[i].index)
180
+ #print(i,str(fm.vocab[i].index))
181
+ except KeyError:
182
+ texts.append(fm.vocab["<unk>"].index)
183
+ label[-1] = 1
184
+ #texts = texts.rstrip() + "\t"
185
+ #texts += " ".join(label) + "\n"
186
+ #alls.append((n,texts,label))
187
+ labels.append(np.array(label))
188
+ text.append(np.array(texts))
189
+ num.append(n)
190
+ sent = ""
191
+ sparate = []
192
+ texts = []
193
+ label = []
194
+ ruiseki = 0
195
+ ruiseki2 = 0
196
+ continue
197
+ sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
198
+ ruiseki2 += len(line)
199
+ sparate.append(ruiseki2)
200
+ return text,labels
201
+
202
+
203
+
204
+ def nmni_carte(lis):
205
+ #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
206
+ wakati = MeCab.Tagger("-Owakati -b 81920")
207
+ #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
208
+ #fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
209
+ #with open('fm.pickle', 'wb') as f:
210
+ # pickle.dump(fm, f)
211
+ #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
212
+ #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
213
+ with open('fm.pickle', 'rb') as f:
214
+ fm = pickle.load(f)
215
+ texts = []
216
+ sent = ""
217
+ sparate = []
218
+ label = []
219
+ ruiseki = 0
220
+ ruiseki2 = 0
221
+ alls = []
222
+ labels, text, num = [], [], []
223
+ allab, altex, fukugenss = [], [], []
224
+ #for n in tqdm(range(26431)):
225
+ for n in tqdm(range(108)):
226
+ fukugens = []
227
+ for line in open(lis+str(n)+".txt"):
228
+ line = line.strip()
229
+ if line == "":
230
+ continue
231
+ sent = wakati.parse(line).split(" ")[:-1]
232
+ flag = 0
233
+ label = []
234
+ texts = []
235
+ fukugen = []
236
+ for i in sent:
237
+ try:
238
+ texts.append(fm.vocab[i].index)
239
+ except KeyError:
240
+ texts.append(fm.vocab["<unk>"].index)
241
+ fukugen.append(i)
242
+ label.append(0)
243
+ label[-1] = 1
244
+ labels.append(np.array(label))
245
+ text.append(np.array(texts))
246
+ #labels.append(label)
247
+ #text.append(texts)
248
+ fukugens.append(fukugen)
249
+ allab.append(labels)
250
+ altex.append(text)
251
+ fukugenss.append(fukugens)
252
+ labels, text, fukugens= [], [], []
253
+ return altex, allab, fukugenss
254
+
255
+
256
+ def nmni_finetune_s(lis):
257
+ #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
258
+ wakati = MeCab.Tagger("-Owakati -b 81920")
259
+ #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
260
+ fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
261
+ with open('fm.pickle', 'wb') as f:
262
+ pickle.dump(fm, f)
263
+ #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
264
+ #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
265
+ texts = []
266
+ sent = ""
267
+ sparate = []
268
+ label = []
269
+ ruiseki = 0
270
+ ruiseki2 = 0
271
+ alls = []
272
+ labels, text, num = [], [], []
273
+ for n, line in enumerate(open(lis)):
274
+ line = line.strip("\t").rstrip("\n")
275
+ sent = wakati.parse(line).split(" ")[:-1]
276
+ flag = 0
277
+ label = []
278
+ texts = []
279
+ for i in sent:
280
+ try:
281
+ texts.append(fm.vocab[i].index)
282
+ except KeyError:
283
+ texts.append(fm.vocab["<unk>"].index)
284
+ label.append(0)
285
+ label[-1] = 1
286
+ labels.append(np.array(label))
287
+ text.append(np.array(texts))
288
+ return text,labels
289
+
290
+
291
+ def nmni_finetune_ss(lis):
292
+ #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
293
+ wakati = MeCab.Tagger("-Owakati -b 81920")
294
+ fm = gensim.models.KeyedVectors.load_word2vec_format('cc.ja.300.vec', binary=False)
295
+ with open('fm.pickle', 'wb') as f:
296
+ pickle.dump(fm, f)
297
+ #with open('fm.pickle', 'rb') as f:
298
+ # fm = pickle.load(f)
299
+ #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
300
+ #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
301
+ t,l =[],[]
302
+ for i in range(108):
303
+ texts = []
304
+ sent = ""
305
+ sparate = []
306
+ label = []
307
+ ruiseki = 0
308
+ ruiseki2 = 0
309
+ alls = []
310
+ labels, text, num = [], [], []
311
+ for n, line in enumerate(open(lis+str(i)+".txt")):
312
+ line = line.strip("\t").rstrip("\n")
313
+ if line == "":
314
+ continue
315
+ sent = wakati.parse(line).split(" ")[:-1]
316
+ flag = 0
317
+ label = []
318
+ texts = []
319
+ for i in sent:
320
+ try:
321
+ texts.append(fm.vocab[i].index)
322
+ except KeyError:
323
+ texts.append(fm.vocab["<unk>"].index)
324
+ label.append(0)
325
+ label[-1] = 1
326
+ labels.append(np.array(label))
327
+ text.append(np.array(texts))
328
+ t.append(text)
329
+ l.append(labels)
330
+ return t,l
331
+
332
+ #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
333
+ #print(model.get_subwords("間質性肺炎"))
334
+ #print(model.get_subwords("誤嚥性肺炎"))
335
+ #print(model.get_subwords("談話ユニット分割"))
336
+
337
+ """
338
+ texts = []
339
+ sent = ""
340
+ sparate = []
341
+ label = []
342
+ ruiseki = 0
343
+ ruiseki2 = 0
344
+ alls = []
345
+ for n, line in enumerate(open("/clwork/ando/SEGBOT/randomdata.tsv")):
346
+ line = line.strip("\t").rstrip("\n")
347
+ if line == "":
348
+ if sent == "":
349
+ continue
350
+ alls.append(sent)
351
+ sent = ""
352
+ continue
353
+ else:
354
+ sent += line
355
+ if len(sent) != 0:
356
+ alls.append(sent)
357
+ random.shuffle(alls)
358
+ #v = random.sample(alls, 300)
359
+ #for i in v:
360
+ # alls.remove(i)
361
+ #t = random.sample(alls, 300)
362
+ #for i in t:
363
+ # alls.remove(i)
364
+ with open("randomdata_concat.tsv","a")as f:
365
+ f.write("\n".join())
366
+ #with open("dev_fix.tsv","a")as f:
367
+ # for i in v:
368
+ # f.write("\n".join(i))
369
+ # f.write("\n\n")
370
+ #with open("test_fix.tsv","a")as f:
371
+ # for i in t:
372
+ # f.write("\n".join(i))
373
+ # f.write("\n\n")
374
+ """
375
+
376
+ """
377
+ out = ""
378
+ for line in open("/clwork/ando/SEGBOT_BERT/alldata2_bert.tsv"):
379
+ line = line.split("\t")
380
+ line = line[0].strip()
381
+ if line == "" or "サマリ" in line:
382
+ continue
383
+ out += line + "\n"
384
+ with open("alldata3.tsv","w")as f:
385
+ f.write(out)
386
+ """
387
+ """
388
+ #wakati = MeCab.Tagger("-Owakati -b 81920 -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
389
+ wakati = MeCab.Tagger("-Owakati -b 81920")
390
+
391
+ with open('fm_space.pickle', 'rb') as f:
392
+ fm = pickle.load(f)
393
+ #model = gensim.models.KeyedVectors.load_word2vec_format("/clwork/ando/SEGBOT/cc.ja.300.vec", binary=False)
394
+ #model = ft.load_model("/clwork/ando/SEGBOT/fast/cc.ja.300.bin")
395
+ texts = []
396
+ sent = ""
397
+ sparate = []
398
+ label = []
399
+ ruiseki = 0
400
+ ruiseki2 = 0
401
+ alls = []
402
+ for n, line in enumerate(open("/clwork/ando/SEGBOT/train_fix.tsv")):
403
+ line = line.strip("\t").rstrip("\n")
404
+ #print(line)
405
+ if line == "":
406
+ if sent == "":
407
+ continue
408
+ sent = wakati.parse(sent).split(" ")[:-1]
409
+ flag = 0
410
+ for i in sent:
411
+ for j in sparate:
412
+ if ruiseki+len(i) > j and ruiseki < j:
413
+ label.append(1)
414
+ flag = 1
415
+ elif ruiseki+len(i) == j:
416
+ label.append(1)
417
+ flag = 1
418
+ if flag == 0:
419
+ label.append(0)
420
+ flag = 0
421
+ ruiseki += len(i)
422
+ #texts += i + " "
423
+ try:
424
+ #texts.append(model[i])
425
+ texts.append(fm.vocab[i])
426
+ #texts += str(fm.vocab[i].index) + " "
427
+ #print(i,str(fm.vocab[i].index))
428
+ except KeyError:
429
+ texts.append(fm.vocab["<unk>"])
430
+ print(i)
431
+ label[-1] = 1
432
+ #texts = texts.rstrip() + "\t"
433
+ #texts += " ".join(label) + "\n"
434
+ alls.append((str(n),texts,label))
435
+ sent = ""
436
+ sparate = []
437
+ texts = []
438
+ label = []
439
+ ruiseki = 0
440
+ ruiseki2 = 0
441
+ continue
442
+ sent += mojimoji.han_to_zen(line, digit=False, ascii=False)
443
+ ruiseki2 += len(line)
444
+ sparate.append(ruiseki2)
445
+ with open('nm_ni/train.pickle', 'wb') as f:
446
+ pickle.dump(alls, f)
447
+ #print(alls)
448
+ #with open("resepdata_seped.tsv","w")as f:
449
+ # f.write(texts)
+ """
+ wakati = MeCab.Tagger("-Owakati")
+ 
+ #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
+ #with open('fm.pickle', 'wb') as f:
+ #    pickle.dump(fm, f)
+ texts = ""
+ sent = ""
+ sparate = []
+ label = []
+ ruiseki = 0
+ ruiseki2 = 0
+ for line in open("alldata.tsv"):
+     line = line.split("\t")
+     line = line[0].strip()
+     if line == "" or "サマリ" in line:
+         if sent == "":
+             continue
+         sent = wakati.parse(sent).split(" ")[:-1]
+         flag = 0
+         #print(sent,sparate)
+         for i in sent:
+             #print(i)
+             for j in sparate:
+                 if ruiseki+len(i) > j and ruiseki < j:
+                     #print(j)
+                     label.append("1")
+                     flag = 1
+                 elif ruiseki+len(i) == j:
+                     #print(j)
+                     label.append("1")
+                     flag = 1
+             if flag == 0:
+                 label.append("0")
+             flag = 0
+             ruiseki += len(i)
+             #texts += i + " "
+             try:
+                 texts += str(0) + " "
+             except KeyError:
+                 print(i)
+                 #texts += str(fm.vocab["<unk>"].index) + " "
+         label[-1] = "1"
+         texts = texts.rstrip() + "\t"
+         texts += " ".join(label) + "\n"
+         sent = ""
+         sparate = []
+         label = []
+         ruiseki = 0
+         ruiseki2 = 0
+         #print(texts)
+         continue
+     sent += line.strip()
+     ruiseki2 += len(line.strip())
+     sparate.append(ruiseki2)
+ with open("random_labbeled.tsv","w") as f:
+     f.write(texts)
+ """
+ wakati = MeCab.Tagger("-Owakati -u /clwork/ando/SEGBOT/MANBYO_201907_Dic-utf8.dic")
+ 
+ #fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300_space.vec', binary=False)
+ #with open('fm_space.pickle', 'wb') as f:
+ #    pickle.dump(fm, f)
+ with open('fm_space.pickle', 'rb') as f:
+     fm = pickle.load(f)
+ texts = ""
+ sent = ""
+ sparate = []
+ label = []
+ ruiseki = 0
+ ruiseki2 = 0
+ for line in open("/clwork/ando/SEGBOT/alldata_resep.tsv"):
+     line = line.split("\t")
+     line = line[0].strip("\t").rstrip("\n")
+     #print(line)
+     if line == "" or "サマリ" in line:
+         if sent == "":
+             continue
+         print(sent)
+         sent = sent.replace(" ","<space>")
+         sent = wakati.parse(sent).split(" ")[:-1]
+         print(sent)
+         flag = 0
+         #print(sent,sparate)
+         for i in sent:
+             #print(i)
+             for j in sparate:
+                 if ruiseki+len(i) > j and ruiseki < j:
+                     #print(j)
+                     label.append("1")
+                     flag = 1
+                 elif ruiseki+len(i) == j:
+                     #print(j)
+                     label.append("1")
+                     flag = 1
+             if flag == 0:
+                 label.append("0")
+             flag = 0
+             ruiseki += len(i)
+             #texts += i + " "
+             try:
+                 texts += str(fm.vocab[i].index) + " "
+                 #print(i,str(fm.vocab[i].index))
+             except KeyError:
+                 texts += str(fm.vocab["<unk>"].index) + " "
+         label[-1] = "1"
+         texts = texts.rstrip() + "\t"
+         texts += " ".join(label) + "\n"
+         sent = ""
+         sparate = []
+         label = []
+         ruiseki = 0
+         ruiseki2 = 0
+         #print(texts)
+         continue
+     sent += line.strip("\t")
+     ruiseki2 += len(line)
+     sparate.append(ruiseki2)
+ with open("alldata2_space.tsv","w") as f:
+     f.write(texts)
+ """
+ 
+ """
+ wakati = MeCab.Tagger("-Owakati")
+ 
+ fm = gensim.models.KeyedVectors.load_word2vec_format('/clwork/ando/SEGBOT/cc.ja.300.vec', binary=False)
+ texts = ""
+ sent = ""
+ cand = ""
+ sparate = []
+ label = []
+ ruiseki = 0
+ ruiseki2 = 0
+ flag2 = 1
+ for line in open("data2.tsv"):
+     line = line.split("\t")
+     if flag2 == 1:
+         cand = line
+         flag2 = 2
+         continue
+     if flag2 == 2:
+         flag2 = 1
+         #print(line,cand)
+         for n,z in enumerate(zip(cand,line)):
+             i = z[0]
+             j = z[1]
+             n = n+1
+             if i == "":
+                 sent = wakati.parse(sent).split(" ")[:-1]
+                 flag = 0
+                 #print(sent,sparate)
+                 for i in sent:
+                     #print(i)
+                     for j in sparate:
+                         if ruiseki+len(i) > j and ruiseki < j:
+                             #print(j)
+                             label.append("1")
+                             flag = 1
+                         elif ruiseki+len(i) == j:
+                             #print(j)
+                             label.append("1")
+                             flag = 1
+                     if flag == 0:
+                         label.append("0")
+                     flag = 0
+                     ruiseki += len(i)
+                     #texts += i + " "
+                     try:
+                         texts += str(fm.vocab[i].index) + " "
+                     except KeyError:
+                         texts += str(fm.vocab["<unk>"].index) + " "
+                 label[-1] = "1"
+                 texts = texts.rstrip() + "\t"
+                 texts += " ".join(label) + "\n"
+                 sent = ""
+                 sparate = []
+                 label = []
+                 ruiseki = 0
+                 ruiseki2 = 0
+                 #print(texts)
+                 break
+             if j == "|":
+                 sparate.append(n)
+             sent += i
+ with open("alldata.tsv","w") as f:
+     f.write(texts)
+ """
fm.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f4c02d5957824106f6217e9a56d89ee5b7ca9ae399c7a49af8dc062e1ea0be99
+ size 2521658187
functionforDownloadButtons.py ADDED
@@ -0,0 +1,171 @@
+ import streamlit as st
+ import pickle
+ import pandas as pd
+ import json
+ import base64
+ import math
+ import uuid
+ import re
+ 
+ import importlib.util
+ 
+ import jupytext
+ from bokeh.models.widgets import Div
+ 
+ 
+ def import_from_file(module_name: str, filepath: str):
+     """
+     Imports a module from file.
+ 
+     Args:
+         module_name (str): Assigned to the module's __name__ parameter (does not
+             influence how the module is named outside of this function)
+         filepath (str): Path to the .py file
+ 
+     Returns:
+         The module
+     """
+     spec = importlib.util.spec_from_file_location(module_name, filepath)
+     module = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(module)
+     return module
+ 
+ 
+ def notebook_header(text):
+     """
+     Insert section header into a jinja file, formatted as notebook cell.
+ 
+     Leave 2 blank lines before the header.
+     """
+     return f"""# # {text}
+ 
+ """
+ 
+ 
+ def code_header(text):
+     """
+     Insert section header into a jinja file, formatted as Python comment.
+ 
+     Leave 2 blank lines before the header.
+     """
+     seperator_len = (75 - len(text)) / 2
+     seperator_len_left = math.floor(seperator_len)
+     seperator_len_right = math.ceil(seperator_len)
+     return f"# {'-' * seperator_len_left} {text} {'-' * seperator_len_right}"
+ 
+ 
+ def to_notebook(code):
+     """Converts Python code to Jupyter notebook format."""
+     notebook = jupytext.reads(code, fmt="py")
+     return jupytext.writes(notebook, fmt="ipynb")
+ 
+ 
+ def open_link(url, new_tab=True):
+     """Dirty hack to open a new web page with a streamlit button."""
+     # From: https://discuss.streamlit.io/t/how-to-link-a-button-to-a-webpage/1661/3
+     if new_tab:
+         js = f"window.open('{url}')"  # New tab or window
+     else:
+         js = f"window.location.href = '{url}'"  # Current tab
+     html = '<img src onerror="{}">'.format(js)
+     div = Div(text=html)
+     st.bokeh_chart(div)
+ 
+ 
+ def download_button(object_to_download, download_filename, button_text):
+     """
+     Generates a link to download the given object_to_download.
+ 
+     From: https://discuss.streamlit.io/t/a-download-button-with-custom-css/4220
+ 
+     Params:
+     ------
+     object_to_download: The object to be downloaded.
+     download_filename (str): filename and extension of file. e.g. mydata.csv,
+         some_txt_output.txt
+     download_link_text (str): Text to display for download link.
+     button_text (str): Text to display on download button (e.g. 'click here to download file')
+     pickle_it (bool): If True, pickle file.
+ 
+     Returns:
+     -------
+     (str): the anchor tag to download object_to_download
+ 
+     Examples:
+     --------
+     download_link(your_df, 'YOUR_DF.csv', 'Click to download data!')
+     download_link(your_str, 'YOUR_STRING.txt', 'Click to download text!')
+     """
+     # if pickle_it:
+     #     try:
+     #         object_to_download = pickle.dumps(object_to_download)
+     #     except pickle.PicklingError as e:
+     #         st.write(e)
+     #         return None
+ 
+     if isinstance(object_to_download, bytes):
+         pass
+     elif isinstance(object_to_download, pd.DataFrame):
+         object_to_download = object_to_download.to_csv(index=False)
+     # Try JSON encode for everything else
+     else:
+         object_to_download = json.dumps(object_to_download)
+ 
+     try:
+         # some strings <-> bytes conversions necessary here
+         b64 = base64.b64encode(object_to_download.encode()).decode()
+     except AttributeError:
+         b64 = base64.b64encode(object_to_download).decode()
+ 
+     button_uuid = str(uuid.uuid4()).replace("-", "")
+     button_id = re.sub(r"\d+", "", button_uuid)
+ 
+     custom_css = f"""
+         <style>
+             #{button_id} {{
+                 display: inline-flex;
+                 align-items: center;
+                 justify-content: center;
+                 background-color: rgb(255, 255, 255);
+                 color: rgb(38, 39, 48);
+                 padding: .25rem .75rem;
+                 position: relative;
+                 text-decoration: none;
+                 border-radius: 4px;
+                 border-width: 1px;
+                 border-style: solid;
+                 border-color: rgb(230, 234, 241);
+                 border-image: initial;
+             }}
+             #{button_id}:hover {{
+                 border-color: rgb(246, 51, 102);
+                 color: rgb(246, 51, 102);
+             }}
+             #{button_id}:active {{
+                 box-shadow: none;
+                 background-color: rgb(246, 51, 102);
+                 color: white;
+             }}
+         </style> """
+ 
+     dl_link = (
+         custom_css
+         + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br><br>'
+     )
+     # dl_link = f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}"><input type="button" kind="primary" value="{button_text}"></a><br></br>'
+ 
+     st.markdown(dl_link, unsafe_allow_html=True)
+ 
+ 
+ # def download_link(
+ #     content, label="Download", filename="file.txt", mimetype="text/plain"
+ # ):
+ #     """Create a HTML link to download a string as a file."""
+ #     # From: https://discuss.streamlit.io/t/how-to-download-file-in-streamlit/1806/9
+ #     b64 = base64.b64encode(
+ #         content.encode()
+ #     ).decode()  # some strings <-> bytes conversions necessary here
+ #     href = (
+ #         f'<a href="data:{mimetype};base64,{b64}" download="(unknown)">{label}</a>'
+ #     )
+ #     return href
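The core of `download_button` is the data-URL trick: base64-encode the payload and embed it directly in an `<a download>` anchor's `href`, so no server round-trip is needed. A minimal, framework-free sketch of just that step (`make_download_link` is a hypothetical helper, not part of this file):

```python
import base64

def make_download_link(text: str, filename: str, label: str) -> str:
    # Encode the payload so it can live inside the href itself.
    b64 = base64.b64encode(text.encode()).decode()
    return f'<a download="{filename}" href="data:file/txt;base64,{b64}">{label}</a>'

print(make_download_link("hello", "hello.txt", "Download"))
```

`download_button` wraps the same anchor in per-button CSS (keyed on a UUID-derived `id`) and hands it to `st.markdown` with `unsafe_allow_html=True`.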
logo.png ADDED
model.py ADDED
@@ -0,0 +1,465 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.utils.rnn as R
+ import torch.nn.functional as F
+ from torch.autograd import Variable
+ import numpy as np
+ 
+ 
+ class PointerNetworks(nn.Module):
+     def __init__(self, voca_size, voc_embeddings, word_dim, hidden_dim, is_bi_encoder_rnn, rnn_type, rnn_layers,
+                  dropout_prob, use_cuda, finedtuning, isbanor, batchsize):
+         super(PointerNetworks, self).__init__()
+ 
+         self.word_dim = word_dim
+         self.voca_size = voca_size
+         self.hidden_dim = hidden_dim
+         self.dropout_prob = dropout_prob
+         self.is_bi_encoder_rnn = is_bi_encoder_rnn
+         self.num_rnn_layers = rnn_layers
+         self.rnn_type = rnn_type
+         self.voc_embeddings = voc_embeddings
+         self.finedtuning = finedtuning
+         self.batchsize = batchsize
+ 
+         self.nnDropout = nn.Dropout(dropout_prob)
+         self.isbanor = isbanor
+ 
+         if rnn_type in ['LSTM', 'GRU']:
+             self.decoder_rnn = getattr(nn, rnn_type)(input_size=word_dim,
+                                                      hidden_size=2 * hidden_dim if is_bi_encoder_rnn else hidden_dim,
+                                                      num_layers=rnn_layers,
+                                                      dropout=dropout_prob,
+                                                      batch_first=True)
+             self.encoder_rnn = getattr(nn, rnn_type)(input_size=word_dim,
+                                                      hidden_size=hidden_dim,
+                                                      num_layers=rnn_layers,
+                                                      bidirectional=is_bi_encoder_rnn,
+                                                      dropout=dropout_prob,
+                                                      batch_first=True)
+         else:
+             print('rnn_type should be LSTM or GRU')
+ 
+         self.use_cuda = True
+ 
+         self.nnSELU = nn.SELU()
+ 
+         self.nnEm = nn.Embedding(self.voca_size, self.word_dim, padding_idx=2000001)
+         #self.nnEm = nn.Embedding.from_pretrained(self.voc_embeddings, freeze=self.finedtuning, padding_idx=-1)
+         self.initEmbeddings(self.voc_embeddings)
+         if self.use_cuda:
+             self.nnEm = self.nnEm.cuda()
+ 
+         if self.is_bi_encoder_rnn:
+             self.num_encoder_bi = 2
+         else:
+             self.num_encoder_bi = 1
+ 
+         self.nnW1 = nn.Linear(self.num_encoder_bi * hidden_dim, self.num_encoder_bi * hidden_dim, bias=False)
+         self.nnW2 = nn.Linear(self.num_encoder_bi * hidden_dim, self.num_encoder_bi * hidden_dim, bias=False)
+         self.nnV = nn.Linear(self.num_encoder_bi * hidden_dim, 1, bias=False)
+ 
+     def initEmbeddings(self, weights):
+         self.nnEm.weight.data.copy_(torch.from_numpy(weights))
+         self.nnEm.weight.requires_grad = self.finedtuning
+ 
+     def initHidden(self, hsize, batchsize):
+         #hsize = self.hidden_dim
+         #batchsize = self.batchsize
+         if self.rnn_type == 'LSTM':
+             h_0 = Variable(torch.zeros(self.num_encoder_bi * self.num_rnn_layers, batchsize, hsize))
+             c_0 = Variable(torch.zeros(self.num_encoder_bi * self.num_rnn_layers, batchsize, hsize))
+             if self.use_cuda:
+                 h_0 = h_0.cuda()
+                 c_0 = c_0.cuda()
+             return (h_0, c_0)
+         else:
+             h_0 = Variable(torch.zeros(self.num_encoder_bi * self.num_rnn_layers, batchsize, hsize))
+             if self.use_cuda:
+                 h_0 = h_0.cuda()
+             return h_0
+ 
+     def _run_rnn_packed(self, cell, x, x_lens, h=None):
+         #print(x_lens)
+         x_packed = R.pack_padded_sequence(x, x_lens.data.tolist(),
+                                           batch_first=True, enforce_sorted=False)
+         if h is not None:
+             output, h = cell(x_packed, h)
+         else:
+             output, h = cell(x_packed)
+         output, _ = R.pad_packed_sequence(output, batch_first=True)
+         return output, h
+ 
+     def pointerEncoder(self, Xin, lens):
+         self.bn_inputdata = nn.BatchNorm1d(self.word_dim, affine=False, track_running_stats=False)
+ 
+         batch_size, maxL = Xin.size()
+         X = self.nnEm(Xin)  # N L C
+ 
+         if self.isbanor and maxL > 1:
+             X = X.permute(0, 2, 1)  # N C L
+             X = self.bn_inputdata(X)
+             X = X.permute(0, 2, 1)  # N L C
+ 
+         X = self.nnDropout(X)
+ 
+         encoder_lstm_co_h_o = self.initHidden(self.hidden_dim, batch_size)
+         o, h = self._run_rnn_packed(self.encoder_rnn, X, lens, encoder_lstm_co_h_o)  # batch_first=True
+         o = o.contiguous()
+         o = self.nnDropout(o)
+ 
+         return o, h
+ 
+     def pointerLayer(self, en, di):
+         """
+         :param en: [L,H]
+         :param di: [H,]
+         :return:
+         """
+         WE = self.nnW1(en)
+         exdi = di.expand_as(en)
+         WD = self.nnW2(exdi)
+ 
+         nnV = self.nnV(self.nnSELU(WE + WD))
+         nnV = nnV.permute(1, 0)
+         nnV = self.nnSELU(nnV)
+ 
+         #TODO: for log loss
+         att_weights = F.softmax(nnV, dim=1)
+         logits = F.log_softmax(nnV, dim=1)
+ 
+         return logits, att_weights
+ 
+     def training_decoder(self, hn, hend, X, Xindex, Yindex, lens):
+         loss_function = nn.NLLLoss()
+         batch_loss = 0
+         LoopN = 0
+         batch_size = len(lens)
+         for i in range(len(lens)):  # Loop batch size
+             curX_index = Xindex[i]
+             #print(curX_index)
+             curY_index = Yindex[i]
+             curL = lens[i]
+             curX = X[i]
+             #print(curX)
+ 
+             x_index_var = Variable(torch.from_numpy(curX_index.astype(np.int64)))
+             if self.use_cuda:
+                 x_index_var = x_index_var.cuda()
+             cur_lookup = curX[x_index_var]
+             #print(cur_lookup)
+ 
+             curX_vectors = self.nnEm(cur_lookup)  # output: [seq,features]
+             curX_vectors = curX_vectors.unsqueeze(0)  # [batch, seq, features]
+ 
+             if self.rnn_type == 'LSTM':  # need h_end,c_end
+                 h_end = hend[0].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers, -1)
+                 c_end = hend[1].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers, -1)
+                 curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                 curc0 = c_end[i].unsqueeze(0).permute(1, 0, 2)
+                 h_pass = (curh0, curc0)
+             else:
+                 h_end = hend.permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers, -1)
+                 curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                 h_pass = curh0
+ 
+             decoder_out, _ = self.decoder_rnn(curX_vectors, h_pass)
+             decoder_out = decoder_out.squeeze(0)  # [seq,features]
+ 
+             curencoder_hn = hn[i, 0:curL, :]  # hn[batch,seq,H] --> [seq,H], i is loop batch size
+ 
+             for j in range(len(decoder_out)):  # Loop di
+                 #print(len(decoder_out),curY_index)
+                 cur_dj = decoder_out[j]
+                 cur_groundy = curY_index[j]
+                 cur_start_index = curX_index[j]
+                 predict_range = list(range(cur_start_index, curL))
+ 
+                 # TODO: make it point backward, only consider predict_range in current time step
+                 # align groundtruth
+                 cur_groundy_var = Variable(torch.LongTensor([int(cur_groundy) - int(cur_start_index)]))
+                 if self.use_cuda:
+                     cur_groundy_var = cur_groundy_var.cuda()
+ 
+                 curencoder_hn_back = curencoder_hn[predict_range, :]
+                 cur_logists, cur_weights = self.pointerLayer(curencoder_hn_back, cur_dj)
+ 
+                 batch_loss = batch_loss + loss_function(cur_logists, cur_groundy_var)
+                 LoopN = LoopN + 1
+ 
+         batch_loss = batch_loss / LoopN
+         return batch_loss
+ 
+     def neg_log_likelihood(self, Xin, index_decoder_x, index_decoder_y, lens):
+         '''
+         :param Xin: stack_x, [allseq,wordDim]
+         :param Yin:
+         :param lens:
+         :return:
+         '''
+         encoder_hn, encoder_h_end = self.pointerEncoder(Xin, lens)
+         loss = self.training_decoder(encoder_hn, encoder_h_end, Xin, index_decoder_x, index_decoder_y, lens)
+         return loss
+ 
+     def test_decoder(self, hn, hend, X, Yindex, lens):
+         loss_function = nn.NLLLoss()
+         batch_loss = 0
+         LoopN = 0
+ 
+         batch_boundary = []
+         batch_boundary_start = []
+         batch_align_matrix = []
+ 
+         batch_size = len(lens)
+ 
+         for i in range(len(lens)):  # Loop batch size
+             curL = lens[i]
+             curY_index = Yindex[i]
+             curX = X[i]
+             cur_end_boundary = curY_index[-1]
+ 
+             cur_boundary = []
+             cur_b_start = []
+             cur_align_matrix = []
+ 
+             cur_sentence_vectors = self.nnEm(curX)  # output: [seq,features]
+ 
+             if self.rnn_type == 'LSTM':  # need h_end,c_end
+                 h_end = hend[0].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers, -1)
+                 c_end = hend[1].permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers, -1)
+                 curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                 curc0 = c_end[i].unsqueeze(0).permute(1, 0, 2)
+                 h_pass = (curh0, curc0)
+             else:  # only need h_end
+                 h_end = hend.permute(1, 0, 2).contiguous().view(batch_size, self.num_rnn_layers, -1)
+                 curh0 = h_end[i].unsqueeze(0).permute(1, 0, 2)
+                 h_pass = curh0
+ 
+             curencoder_hn = hn[i, 0:curL, :]  # hn[batch,seq,H] --> [seq,H], i is loop batch size
+ 
+             Not_break = True
+             loop_in = cur_sentence_vectors[0, :].unsqueeze(0).unsqueeze(0)  # [1,1,H]
+             loop_hc = h_pass
+ 
+             loopstart = 0
+             loop_j = 0
+             while (Not_break):  # if not end
+                 loop_o, loop_hc = self.decoder_rnn(loop_in, loop_hc)
+ 
+                 #TODO: make it point backward
+                 predict_range = list(range(loopstart, curL))
+                 curencoder_hn_back = curencoder_hn[predict_range, :]
+                 cur_logists, cur_weights = self.pointerLayer(curencoder_hn_back, loop_o.squeeze(0).squeeze(0))
+ 
+                 cur_align_vector = np.zeros(curL)
+                 cur_align_vector[predict_range] = cur_weights.data.cpu().numpy()[0]
+                 cur_align_matrix.append(cur_align_vector)
+ 
+                 #TODO: align groundtruth
+                 if loop_j > len(curY_index) - 1:
+                     cur_groundy = curY_index[-1]
+                 else:
+                     cur_groundy = curY_index[loop_j]
+ 
+                 cur_groundy_var = Variable(torch.LongTensor([max(0, int(cur_groundy) - loopstart)]))
+                 if self.use_cuda:
+                     cur_groundy_var = cur_groundy_var.cuda()
+ 
+                 batch_loss = batch_loss + loss_function(cur_logists, cur_groundy_var)
+ 
+                 #TODO: get predicted boundary
+                 topv, topi = cur_logists.data.topk(1)
+                 pred_index = topi[0][0]
+ 
+                 #TODO: align pred_index to original seq
+                 ori_pred_index = pred_index + loopstart
+ 
+                 if cur_end_boundary == ori_pred_index:
+                     cur_boundary.append(ori_pred_index)
+                     cur_b_start.append(loopstart)
+                     Not_break = False
+                     loop_j = loop_j + 1
+                     LoopN = LoopN + 1
+                     break
+                 else:
+                     cur_boundary.append(ori_pred_index)
+                     loop_in = cur_sentence_vectors[ori_pred_index + 1, :].unsqueeze(0).unsqueeze(0)
+                     cur_b_start.append(loopstart)
+                     loopstart = ori_pred_index + 1  # start = pred_end + 1
+                     loop_j = loop_j + 1
+                     LoopN = LoopN + 1
+ 
+             # For each instance in batch
+             batch_boundary.append(cur_boundary)
+             batch_boundary_start.append(cur_b_start)
+             batch_align_matrix.append(cur_align_matrix)
+ 
+         batch_loss = batch_loss / LoopN
+ 
+         batch_boundary = np.array(batch_boundary)
+         batch_boundary_start = np.array(batch_boundary_start)
+         batch_align_matrix = np.array(batch_align_matrix)
+ 
+         return batch_loss, batch_boundary, batch_boundary_start, batch_align_matrix
+ 
+     def predict(self, Xin, index_decoder_y, lens):
+         batch_size = index_decoder_y.shape[0]
+         encoder_hn, encoder_h_end = self.pointerEncoder(Xin, lens)
+         batch_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.test_decoder(encoder_hn, encoder_h_end, Xin, index_decoder_y, lens)
+         return batch_loss, batch_boundary, batch_boundary_start, batch_align_matrix
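The scoring inside `pointerLayer` — `v · SELU(W1·e + W2·d)` for each encoder state `e` against the current decoder state `d`, followed by a log-softmax over the remaining positions — can be sketched in plain NumPy. The random matrices here are stand-ins for the `nnW1`/`nnW2`/`nnV` linear layers, not the trained weights:

```python
import numpy as np

def pointer_scores(encoder_states, decoder_state, W1, W2, v):
    """Log-probability of pointing at each encoder position, one decoder step."""
    def selu(x, a=1.6732632423543772, s=1.0507009873554805):
        return s * np.where(x > 0, x, a * (np.exp(x) - 1))
    # W1·e + W2·d for every encoder position (d is broadcast over rows)
    e = selu(encoder_states @ W1.T + decoder_state @ W2.T)  # [L, H]
    scores = selu(e @ v)                                    # [L], second SELU as in pointerLayer
    return scores - np.log(np.sum(np.exp(scores)))          # log-softmax over positions

rng = np.random.default_rng(0)
H = 4
lp = pointer_scores(rng.normal(size=(3, H)), rng.normal(size=H),
                    rng.normal(size=(H, H)), rng.normal(size=(H, H)),
                    rng.normal(size=H))
print(np.exp(lp))  # attention weights over the 3 candidate positions
```

At decode time `test_decoder` restricts `encoder_states` to `predict_range` (positions from the current segment start onward) and takes the argmax of these log-probabilities as the predicted boundary.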
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ seaborn
+ matplotlib
+ streamlit == 0.87
+ pandas == 1.2.4
+ keybert
+ flair
+ click<8
run_segbot.py CHANGED
@@ -1,5 +1,4 @@
  import re
- from nltk.tokenize import word_tokenize
  import pickle
  import numpy as np
  import random
@@ -8,99 +7,67 @@ from solver import TrainSolver
  
  from model import PointerNetworks
  import gensim
- from tqdm import tqdm
- 
- class Lang:
-     def __init__(self, name):
-         self.name = name
-         self.word2index = {"RE_DIGITS":1,"UNKNOWN":0,"PADDING":2000001}
-         self.word2count = {"RE_DIGITS":1,"UNKNOWN":1,"PADDING":1}
-         self.index2word = {2000001: "PADDING", 1: "RE_DIGITS", 0: "UNKNOWN"}
-         self.n_words = 3  # Count SOS and EOS
- 
-     def addSentence(self, sentence):
-         for word in sentence.strip('\n').strip('\r').split(' '):
-             self.addWord(word)
- 
-     def addWord(self, word):
-         if word not in self.word2index:
-             self.word2index[word] = self.n_words
-             self.word2count[word] = 1
-             self.index2word[self.n_words] = word
-             self.n_words += 1
-         else:
-             self.word2count[word] += 1
- 
- 
- def mytokenizer(inS,all_dict):
- 
-     #repDig = re.sub(r'\d+[\.,/]?\d+','RE_DIGITS',inS)
-     #repDig = re.sub(r'\d*[\d,]*\d+', 'RE_DIGITS', inS)
-     toked = inS
-     or_toked = inS
-     re_unk_list = []
-     ori_list = []
- 
-     for (i,t) in enumerate(toked):
-         if t not in all_dict and t not in ['RE_DIGITS']:
-             re_unk_list.append('UNKNOWN')
-             ori_list.append(or_toked[i])
-         else:
-             re_unk_list.append(t)
-             ori_list.append(or_toked[i])
- 
-     labey_edus = [0]*len(re_unk_list)
-     labey_edus[-1] = 1
- 
-     return ori_list,re_unk_list,labey_edus
- 
- 
- def get_mapping(X,Y,D):
- 
-     X_map = []
-     for w in X:
-         if w in D:
-             X_map.append(D[w])
          else:
-             X_map.append(D['UNKNOWN'])
- 
-     X_map = np.array([X_map])
-     Y_map = np.array([Y])
- 
-     return X_map,Y_map
- 
- 
- def get_model():
      with open('model.pickle', 'rb') as f:
          mysolver = pickle.load(f)
-     return mysolver
- 
-     #for i in tqdm(range(0,26431)):
-     test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,index2word, fukugen)
-     #test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,0)
-     #with open(str(i)+"seped","w") as f:
-     #    f.write(o)
-     #test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,0)
-     print(test_pre, test_rec, test_f1)
-     #start_b = visdata[3][0]
-     #end_b = visdata[2][0] + 1
-     #segments = []
- 
-     #for i, END in enumerate(end_b):
-     #    START = start_b[i]
-     #    segments.append(' '.join(ori_X[START:END]))
- 
-     return test_pre, test_rec, test_f1
- 
  import re
  import pickle
  import numpy as np
  import random
  
  from model import PointerNetworks
  import gensim
+ import MeCab
+ import pysbd
+ 
+ def create_data(doc,fm,split_method):
+     wakati = MeCab.Tagger("-Owakati -b 81920")
+     seg = pysbd.Segmenter(language="ja", clean=False)
+     texts = []
+     sent = ""
+     label = []
+     alls = []
+     labels, text, num = [], [], []
+     allab, altex, fukugenss = [], [], []
+     for n in range(1):
+         fukugens = []
+         if split_method == "pySBD":
+             lines = seg.segment(doc)
+         else:
+             doc = doc.strip().replace("。","。\n").replace(".",".\n")
+             doc = re.sub("(\n)+","\n",doc)
+             lines = doc.split("\n")
+         for line in lines:
+             line = line.strip()
+             if line == "":
+                 continue
+             sent = wakati.parse(line).split(" ")[:-1]
+             flag = 0
+             label = []
+             texts = []
+             fukugen = []
+             for i in sent:
+                 try:
+                     texts.append(fm.vocab[i].index)
+                 except KeyError:
+                     texts.append(fm.vocab["<unk>"].index)
+                 fukugen.append(i)
+                 label.append(0)
+             label[-1] = 1
+             labels.append(np.array(label))
+             text.append(np.array(texts))
+             fukugens.append(fukugen)
+         allab.append(labels)
+         altex.append(text)
+         fukugenss.append(fukugens)
+         labels, text, fukugens = [], [], []
+     return altex, allab, fukugenss
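The non-pySBD branch of `create_data` is a rule-based splitter: insert a newline after each Japanese or Latin full stop, collapse newline runs, and drop empty lines. Isolated, it looks like this (`simple_split` is a hypothetical name for illustration):

```python
import re

def simple_split(doc: str) -> list:
    # Insert a newline after each sentence-ending mark, collapse runs of
    # newlines, then drop empty lines — mirrors the fallback branch above.
    doc = doc.strip().replace("。", "。\n").replace(".", ".\n")
    doc = re.sub(r"\n+", "\n", doc)
    return [line.strip() for line in doc.split("\n") if line.strip()]

print(simple_split("今日は晴れ。明日は雨。"))  # → ['今日は晴れ。', '明日は雨。']
```

Note that splitting on every `.` also breaks on abbreviations and decimals, which is presumably why pySBD is offered as the default method.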
55
+
56
+
57
+ def generate(doc, mymodel, fm, index2word, split_method):
58
+ X_tes, Y_tes, fukugen = create_data(doc,fm,split_method)
59
+ output_texts = mymodel.check_accuracy(X_tes, Y_tes,index2word, fukugen)
60
+
61
+ return output_texts
62
+
63
+
64
+
65
+ def setup():
66
+ with open('index2word.pickle', 'rb') as f:
67
+ index2word = pickle.load(f)
68
  with open('model.pickle', 'rb') as f:
69
  mysolver = pickle.load(f)
70
+ with open('fm.pickle', 'rb') as f:
71
+ fm = pickle.load(f)
72
+
73
+ return mysolver,fm,index2word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
solver.py CHANGED
@@ -6,7 +6,6 @@ from torch.autograd import Variable
6
  import random
7
  from torch.nn.utils import clip_grad_norm
8
  import copy
9
- from tqdm import tqdm
10
 
11
  import os
12
  import pickle
@@ -56,76 +55,36 @@ def align_variable_numpy(X,maxL,paddingNumber):
56
 
57
 
58
  def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
59
-
60
-
61
- if batch_size != None:
62
- select_index = random.sample(range(len(numpyY)), batch_size)
63
- else:
64
- select_index = np.array(range(len(numpyY)))
65
 
66
  select_index = np.array(range(len(numpyX)))
67
 
68
  batch_x = [copy.deepcopy(numpyX[i]) for i in select_index]
69
  batch_y = [copy.deepcopy(numpyY[i]) for i in select_index]
70
 
71
- #print(batch_y)
72
  index_decoder_X,index_decoder_Y = get_decoder_index_XY(batch_y)
73
- #index_decoder = [get_decoder_index_XY(i) for i in batch_y]
74
- #index_decoder_X = [i[0] for i in index_decoder]
75
- #index_decoder_Y = [i[1] for i in index_decoder]
76
- #print(index_decoder_Y)
77
-
78
-
79
- #all_lens = []
80
  all_lens = np.array([len(x) for x in batch_y])
81
- #for x in batch_y:
82
- # print(x)
83
- # try:
84
- # all_lens.append(len(x))
85
- # except:
86
- # all_lens.append(1)
87
- #all_lens = np.array(all_lens)
88
 
89
  maxL = np.max(all_lens)
90
 
91
- #idx = all_lens
92
- #print(idx)
93
  idx = np.argsort(all_lens)
94
  idx = np.sort(idx)
95
- #print(idx)
96
- #idx = idx[::-1] # decreasing
97
- #print(idx)
98
  batch_x = [batch_x[i] for i in idx]
99
  batch_y = [batch_y[i] for i in idx]
100
  all_lens = all_lens[idx]
101
 
102
  index_decoder_X = np.array([index_decoder_X[i] for i in idx])
103
  index_decoder_Y = np.array([index_decoder_Y[i] for i in idx])
104
- #print(index_decoder_Y)
105
 
106
  numpy_batch_x = batch_x
107
 
108
-
109
-
110
  batch_x = align_variable_numpy(batch_x,maxL,2000001)
111
  batch_y = align_variable_numpy(batch_y,maxL,2)
112
-
113
-
114
-
115
-
116
-
117
-
118
-
119
- print(len(batch_x))
120
- #batch_x = Variable(torch.from_numpy(batch_x.astype(np.int64)))
121
  batch_x = Variable(torch.from_numpy(np.array(batch_x, dtype="int64")))
122
 
123
-
124
  if use_cuda:
125
  batch_x = batch_x.cuda()
126
 
127
-
128
-
129
  return numpy_batch_x,batch_x,batch_y,index_decoder_X,index_decoder_Y,all_lens,maxL
130
 
131
 
@@ -144,7 +103,6 @@ class TrainSolver(object):
144
  self.lr_decay_epoch = lr_decay_epoch
145
  self.eval_size = eval_size
146
 
147
-
148
  self.dev_x, self.dev_y = dev_x, dev_y
149
 
150
  self.model = model
@@ -152,294 +110,70 @@ class TrainSolver(object):
  self.weight_decay =weight_decay
 
 
-
-
- def sample_dev(self):
- test_tr_x = []
- test_tr_y = []
- select_index = random.sample(range(len(self.train_y)),self.eval_size)
- test_tr_x = [self.train_x[n] for n in select_index]
- test_tr_y = [self.train_y[n] for n in select_index]
-
- return test_tr_x,test_tr_y
-
-
-
-
-
-
-
-
  def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):
 
  tokendic = {}
- #with open('index2word.pickle', 'rb') as f:
- # index2word = pickle.load(f)
  for n,i in enumerate(index2word):
  tokendic[n] = i
- All_C = []
- All_R = []
- All_G = []
- """
- for i,cur_seq_y in enumerate(zip(ground_b,fukugen[nloop])):
- #print(fukugen[nloop])
- fuku = cur_seq_y[1]
- cur_seq_y = cur_seq_y[0]
- index_of_1 = np.where(cur_seq_y==1)[0]
- #print(index_of_1)
- index_pre = pre_b[i]
- inp = x[i]
- #print(len(inp))
- """
- print(len(pre_b), len(ground_b), len(fukugen))
- #global leng
- #print(fukugen)
  for i,cur_seq_y in enumerate(ground_b):
- #print(fukugen[nloop])
  fuku = fukugen[i]
- #cur_seq_y = cur_seq_y[0]
  index_of_1 = np.where(cur_seq_y==1)[0]
- #print(index_of_1)
  index_pre = pre_b[i]
  inp = x[i]
- #print(len(inp))
 
  index_pre = np.array(index_pre)
  END_B = index_of_1[-1]
  index_pre = index_pre[index_pre != END_B]
  index_of_1 = index_of_1[index_of_1 != END_B]
 
- no_correct = len(np.intersect1d(list(index_of_1), list(index_pre)))
- All_C.append(no_correct)
- All_R.append(len(index_pre))
- All_G.append(len(index_of_1))
 
  index_of_1 = list(index_of_1)
  index_pre = list(index_pre)
 
- FN = []
  FP = []
- TP = []
  sent = []
  ex = ""
- for j in inp:
- sent.append(tokendic[int(j.to('cpu').detach().numpy().copy())])
- for k in index_of_1:
- if k not in index_pre:
- FN.append(k)
- if k in index_pre:
- TP.append(k)
  for k in index_pre:
  if k not in index_of_1:
  FP.append(k)
- #if len(FN) == 0 and len(FP) == 0:
- # continue
- #for n,i in enumerate(sent):
  for n,k in enumerate(zip(sent, fuku)):
  f = k[1]
  i = k[0]
  if k == "<pad>":
  continue
  if n in FP:
- ex += f + "<FP>"
- else:
  ex += f
- """
- if n in FN:
- #ex += i + "<FN>"
- ex += i
- elif n in FP:
- ex += i + "<FP>"
- elif n in TP:
- ex += i + "<TP>"
  else:
- ex += i
- """
- #with open(str(nloop)+"_sep_nounk.txt", "a")as f:
- # f.write(ex+"\n")
- #print(i)
- #leng += 1
-
- return All_C,All_R,All_G
-
-
-
-
-
- def get_batch_metric(self,pre_b, ground_b):
-
- b_pr =[]
- b_re =[]
- b_f1 =[]
- for i,cur_seq_y in enumerate(ground_b):
- index_of_1 = np.where(cur_seq_y==1)[0]
- index_pre = pre_b[i]
-
- no_correct = len(np.intersect1d(index_of_1,index_pre))
-
- cur_pre = no_correct / len(index_pre)
- cur_rec = no_correct / len(index_of_1)
- cur_f1 = 2*cur_pre*cur_rec/ (cur_pre+cur_rec)
-
- b_pr.append(cur_pre)
- b_re.append(cur_rec)
- b_f1.append(cur_f1)
-
- return b_pr,b_re,b_f1
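For reference, the metric code this commit removes counts correct, retrieved, and gold boundary indices and then divides; the surviving call sites micro-average by summing those counts over the whole batch before computing precision, recall, and F1. A minimal standalone sketch of that micro-averaging (the helper name `micro_boundary_f1` is hypothetical; only numpy is assumed):

```python
import numpy as np

def micro_boundary_f1(pred_batch, gold_batch):
    # Sum correct / retrieved / gold boundary counts over the batch,
    # then compute precision, recall and F1 once (micro-averaging).
    n_correct = n_pred = n_gold = 0
    for pred, gold in zip(pred_batch, gold_batch):
        gold_idx = np.where(np.asarray(gold) == 1)[0]   # gold boundary positions
        pred_idx = np.asarray(pred)                      # predicted positions
        n_correct += len(np.intersect1d(gold_idx, pred_idx))
        n_pred += len(pred_idx)
        n_gold += len(gold_idx)
    precision = n_correct / n_pred
    recall = n_correct / n_gold
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
```

Unlike the removed per-sequence `get_batch_metric`, this divides only once, so sequences with no predicted or no gold boundaries cannot cause a per-item zero division.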
-
 
 
  def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
- for nloop in tqdm(range(0,108)):
  dataY = data2Y[nloop]
  dataX = data2X[nloop]
  fukugen = fukugen2[nloop]
- #print(len(dataX), len(dataY), len(fukugen))
  need_loop = int(np.ceil(len(dataY) / self.batch_size))
- #need_loop = int(np.ceil(len(dataY) / 1))
- all_ave_loss =[]
- all_boundary =[]
- all_boundary_start = []
- all_align_matrix = []
- all_index_decoder_y =[]
- all_x_save = []
-
- all_C =[]
- all_R =[]
- all_G =[]
 
  for lp in range(need_loop):
  startN = lp*self.batch_size
  endN = (lp+1)*self.batch_size
  if endN > len(dataY):
  endN = len(dataY)
- #print(fukugen)
  fukuge = fukugen[startN:endN]
- #print(startN, endN)
- #print(len(fukugen))
- #print(fukugen)
- #for nloop in tqdm(range(0,26431)):
  numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
  dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)
- #numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
- # dataX, dataY, None, self.use_cuda)
-
- batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,
- index_decoder_Y,
- all_lens)
-
- all_ave_loss.extend([batch_ave_loss.data.item()]) #[batch_ave_loss.data[0]]
- all_boundary.extend(batch_boundary)
- all_boundary_start.extend(batch_boundary_start)
- all_align_matrix.extend(batch_align_matrix)
- all_index_decoder_y.extend(index_decoder_Y)
- all_x_save.extend(numpy_batch_x)
-
- ba_C,ba_R,ba_G = self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop)
-
- all_C.extend(ba_C)
- all_R.extend(ba_R)
- all_G.extend(ba_G)
-
- ba_pre = np.sum(all_C)/ np.sum(all_R)
- ba_rec = np.sum(all_C)/ np.sum(all_G)
- ba_f1 = 2*ba_pre*ba_rec/ (ba_pre+ba_rec)
-
- return np.mean(all_ave_loss),ba_pre,ba_rec,ba_f1, (all_x_save,all_index_decoder_y,all_boundary, all_boundary_start, all_align_matrix)
-
-
- def adjust_learning_rate(self,optimizer,epoch,lr_decay=0.5, lr_decay_epoch=5):
-
- if (epoch % lr_decay_epoch == 0) and (epoch != 0):
- for param_group in optimizer.param_groups:
- param_group['lr'] *= lr_decay
-
-
- def train(self,n):
-
- self.test_train_x, self.test_train_y = self.sample_dev()
-
- optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr, weight_decay=self.weight_decay)
-
- num_each_batch = int(np.round(len(self.train_y) / self.batch_size))
-
- #os.mkdir(self.save_path)
-
- best_i =0
- best_f1 =0
-
- for epoch in range(self.epoch):
- print(epoch)
- self.adjust_learning_rate(optimizer, epoch, 0.8, self.lr_decay_epoch)
-
- track_epoch_loss = []
- for iter in tqdm(range(num_each_batch)):
- #print("epoch:%d,iteration:%d" % (epoch, iter))
-
- self.model.zero_grad()
-
- numpy_batch_x,batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
- self.train_x, self.train_y, self.batch_size, self.use_cuda)
-
- neg_loss = self.model.neg_log_likelihood(batch_x, index_decoder_X, index_decoder_Y,all_lens)
-
- neg_loss_v = float(neg_loss.data.item())
- #print(neg_loss_v)
- track_epoch_loss.append(neg_loss_v)
-
- neg_loss.backward()
-
- clip_grad_norm(self.model.parameters(), 5)
- optimizer.step()
-
- #TODO: after each epoch,check accuracy
-
- self.model.eval()
-
- #tr_batch_ave_loss, tr_pre, tr_rec, tr_f1 ,visdata= self.check_accuracy(self.test_train_x,self.test_train_y)
-
- dev_batch_ave_loss, dev_pre, dev_rec, dev_f1, visdata =self.check_accuracy(self.dev_x,self.dev_y,n)
- print("f1="+str(dev_f1))
- print("loss="+str(dev_batch_ave_loss))
- """
- if best_f1 < dev_f1:
- best_f1 = dev_f1
- best_rec = dev_rec
- best_pre = dev_pre
- best_i = epoch
-
- save_data = [epoch,dev_batch_ave_loss,dev_pre,dev_rec,dev_f1]
-
- save_file_name = 'bs_{}_es_{}_lr_{}_lrdc_{}_wd_{}_epoch_loss_acc_pk_wd.txt'.format(self.batch_size,self.eval_size,self.lr,self.lr_decay_epoch,self.weight_decay)
- """
- #with open(os.path.join(self.save_path,save_file_name), 'a') as f:
- # f.write(','.join(map(str,save_data))+'\n')
-
- #if epoch % 1 ==0 and epoch !=0:
- # torch.save(self.model, os.path.join(self.save_path,r'model_epoch_%d.torchsave'%(epoch)))
 
- self.model.train()
 
- #return best_i,best_pre,best_rec,best_f1
- return best_i,best_f1,n
 
  import random
  from torch.nn.utils import clip_grad_norm
  import copy
 
  import os
  import pickle
 
 
  def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
+ select_index = np.array(range(len(numpyY)))
 
  select_index = np.array(range(len(numpyX)))
 
  batch_x = [copy.deepcopy(numpyX[i]) for i in select_index]
  batch_y = [copy.deepcopy(numpyY[i]) for i in select_index]
 
  index_decoder_X,index_decoder_Y = get_decoder_index_XY(batch_y)
  all_lens = np.array([len(x) for x in batch_y])
 
  maxL = np.max(all_lens)
 
  idx = np.argsort(all_lens)
  idx = np.sort(idx)
  batch_x = [batch_x[i] for i in idx]
  batch_y = [batch_y[i] for i in idx]
  all_lens = all_lens[idx]
 
  index_decoder_X = np.array([index_decoder_X[i] for i in idx])
  index_decoder_Y = np.array([index_decoder_Y[i] for i in idx])
 
  numpy_batch_x = batch_x
 
  batch_x = align_variable_numpy(batch_x,maxL,2000001)
  batch_y = align_variable_numpy(batch_y,maxL,2)
 
  batch_x = Variable(torch.from_numpy(np.array(batch_x, dtype="int64")))
 
  if use_cuda:
  batch_x = batch_x.cuda()
 
  return numpy_batch_x,batch_x,batch_y,index_decoder_X,index_decoder_Y,all_lens,maxL
 
90
 
 
103
  self.lr_decay_epoch = lr_decay_epoch
104
  self.eval_size = eval_size
105
 
 
106
  self.dev_x, self.dev_y = dev_x, dev_y
107
 
108
  self.model = model
 
110
  self.weight_decay =weight_decay
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):
114
 
115
+
116
+
117
  tokendic = {}
 
 
118
  for n,i in enumerate(index2word):
119
  tokendic[n] = i
120
+ sents = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  for i,cur_seq_y in enumerate(ground_b):
 
122
  fuku = fukugen[i]
 
123
  index_of_1 = np.where(cur_seq_y==1)[0]
 
124
  index_pre = pre_b[i]
125
  inp = x[i]
 
126
 
127
  index_pre = np.array(index_pre)
128
  END_B = index_of_1[-1]
129
  index_pre = index_pre[index_pre != END_B]
130
  index_of_1 = index_of_1[index_of_1 != END_B]
131
 
 
 
 
 
132
 
133
  index_of_1 = list(index_of_1)
134
  index_pre = list(index_pre)
135
 
 
136
  FP = []
 
137
  sent = []
138
  ex = ""
139
+ sent = [tokendic[int(j.to('cpu').detach().numpy().copy())] for j in inp]
 
 
 
 
 
 
140
  for k in index_pre:
141
  if k not in index_of_1:
142
  FP.append(k)
143
+ #FP = [int(j.to('cpu').detach().numpy().copy()) for j in FP]
144
+
 
145
  for n,k in enumerate(zip(sent, fuku)):
146
  f = k[1]
147
  i = k[0]
148
  if k == "<pad>":
149
  continue
150
  if n in FP:
 
 
151
  ex += f
152
+ sents.append(ex)
153
+ ex = ""
 
 
 
 
 
 
154
  else:
155
+ ex += f
156
+ sents.append(ex)
157
+ return sents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
 
  def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
+ for nloop in range(1):
  dataY = data2Y[nloop]
  dataX = data2X[nloop]
  fukugen = fukugen2[nloop]
  need_loop = int(np.ceil(len(dataY) / self.batch_size))
 
  for lp in range(need_loop):
  startN = lp*self.batch_size
  endN = (lp+1)*self.batch_size
  if endN > len(dataY):
  endN = len(dataY)
  fukuge = fukugen[startN:endN]
  numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
  dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)
 
+ batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,index_decoder_Y,all_lens)
+ output_texts = self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop)
 
+ return output_texts
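After this commit, `get_batch_micro_metric` no longer scores boundaries; it rebuilds output text by concatenating surface tokens and cutting a new segment after each predicted boundary index. The underlying operation can be sketched standalone (the helper name `split_at_boundaries` is hypothetical):

```python
def split_at_boundaries(tokens, boundaries):
    # Concatenate tokens into segments, starting a new segment
    # immediately after each index listed in boundaries.
    segments, current = [], ""
    cut = set(boundaries)
    for i, tok in enumerate(tokens):
        current += tok
        if i in cut:
            segments.append(current)
            current = ""
    if current:                      # flush the trailing segment
        segments.append(current)
    return segments
```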