adding project source
- Dockerfile +26 -0
- app.py +202 -0
- docker-compose.yml +10 -0
- functionforDownloadButtons.py +171 -0
- kpe.py +67 -0
- kpe_ranker.py +24 -0
- labeling.py +125 -0
- logo.png +0 -0
- main.py +71 -0
- ner_data_construction.py +70 -0
- predict.py +14 -0
- ranker.py +28 -0
- requirements.txt +9 -0
- utils.py +63 -0
Dockerfile
ADDED
@@ -0,0 +1,26 @@
FROM python:3.9

RUN mkdir /app
WORKDIR /app


# download model and put in trained_model folder
# RUN wget https://drive.ahdsoft.dev/s/xp5Mb7bQ34Z7BRX/download/trained_model_10000.pt
# RUN mkdir trained_model
# RUN mv trained_model_10000.pt trained_model/

# download packages
COPY requirements.txt .

ENV HTTP_PROXY http://172.17.0.1:10805
ENV HTTPS_PROXY http://172.17.0.1:10805
ENV http_proxy http://172.17.0.1:10805
ENV https_proxy http://172.17.0.1:10805

RUN pip install git+https://github.com/mohammadkarrabi/NERDA.git
RUN pip install -r requirements.txt
RUN pip install sentence_transformers

COPY . .

ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=7201", "--server.address=0.0.0.0", "--client.showErrorDetails=false"]
app.py
ADDED
@@ -0,0 +1,202 @@
import streamlit as st
import numpy as np
from pandas import DataFrame
# from keybert import KeyBERT
# For Flair (Keybert)
# from flair.embeddings import TransformerDocumentEmbeddings
import seaborn as sns
# For download buttons
from functionforDownloadButtons import download_button
import os
import json

from kpe_ranker import KpeRanker

st.set_page_config(
    page_title="استخراج عبارات کلیدی عهد",  # "Ahd keyphrase extraction"
    page_icon="🎈",
)


def _max_width_():
    max_width_str = f"max-width: 1400px;"
    st.markdown(
        f"""
        <style>
        .reportview-container .main .block-container{{
            {max_width_str}
        }}
        </style>
        """,
        unsafe_allow_html=True,
    )


_max_width_()

c30, c31, c32 = st.columns([2.5, 1, 3])

with c30:
    # st.image("logo.png", width=400)
    st.title("🔑 استخراج عبارات کلیدی")  # "Keyphrase extraction"
    st.header("")


with st.expander("ℹ️ - About this app", expanded=True):

    # About text (Persian): "Keyphrase extraction is a new product from Ahd that,
    # in our evaluations, has shown higher accuracy than competing tools."
    st.write(
        """
- استخراج عبارات کلیدی، محصولی نوین از شرکت عهد است که در ارزیابیهای صورتگرفته، دقت بیشتری را نسبت به رقبا از خود نشان داده است.
        """
    )

    st.markdown("")

st.markdown("")
# st.markdown("## **...**")
with st.form(key="my_form"):

    ce, c1, ce, c2, c3 = st.columns([0.07, 1, 0.07, 5, 0.07])
    with c1:

        # if ModelType == "Default (DistilBERT)":
        #     kw_model = KeyBERT(model=roberta)

        @st.cache(allow_output_mutation=True)
        def load_model():
            return KpeRanker()

        kpe_ranker_extractor = load_model()

        # else:
        #     @st.cache(allow_output_mutation=True)
        #     def load_model():
        #         return KeyBERT("distilbert-base-nli-mean-tokens")
        #
        #     kw_model = load_model()

        top_N = st.slider(
            "# تعداد",  # "count"
            min_value=1,
            max_value=30,
            value=10,
            help="You can choose the number of keywords/keyphrases to display. Between 1 and 30, the default number is 10.",
        )
        # min_Ngrams = st.number_input(
        #     "Minimum Ngram",
        #     min_value=1,
        #     max_value=4,
        #     help="""The minimum value for the ngram range.
        #     *Keyphrase_ngram_range* sets the length of the resulting keywords/keyphrases.
        #     To extract keyphrases, simply set *keyphrase_ngram_range* to (1, 2) or higher depending on the number of words you would like in the resulting keyphrases.""",
        # )

        # max_Ngrams = st.number_input(
        #     "Maximum Ngram",
        #     value=2,
        #     min_value=1,
        #     max_value=4,
        #     help="""The maximum value for the keyphrase_ngram_range.
        #     *Keyphrase_ngram_range* sets the length of the resulting keywords/keyphrases.
        #     To extract keyphrases, simply set *keyphrase_ngram_range* to (1, 2) or higher depending on the number of words you would like in the resulting keyphrases.""",
        # )

        # StopWordsCheckbox = st.checkbox(
        #     "Remove stop words",
        #     help="Tick this box to remove stop words from the document (currently English only)",
        # )

        use_ner = st.checkbox(
            "NER",
            value=True,
            help="استفاده از شناسایی موجودیتهای نامدار",  # "use named-entity recognition"
        )

    with c2:
        doc = st.text_area(
            "متن خود را وارد کنید",  # "enter your text"
            height=510,
        )

        MAX_WORDS = 500
        import re

        res = len(re.findall(r"\w+", doc))
        if res > MAX_WORDS:
            st.warning(
                "⚠️ Your text contains "
                + str(res)
                + " words."
                + " Only the first 500 words will be reviewed. Stay tuned as increased allowance is coming! 😊"
            )

            # note: this slices the first 500 characters, not the first 500 words
            doc = doc[:MAX_WORDS]

        submit_button = st.form_submit_button(label="✨ پردازش")  # "process"


if not submit_button:
    st.stop()


#################################### get keyphrases #######################################################

keywords = kpe_ranker_extractor.extract(text=doc, count=top_N, using_ner=use_ner, return_sorted=True)
# print(keywords)
st.markdown("## **🎈 Check & download results**")

st.header("")

cs, c1, c2, c3, cLast = st.columns([2, 1.5, 1.5, 1.5, 2])

with c1:
    CSVButton2 = download_button(keywords, "Data.csv", "📥 Download (.csv)")
with c2:
    CSVButton2 = download_button(keywords, "Data.txt", "📥 Download (.txt)")
with c3:
    CSVButton2 = download_button(keywords, "Data.json", "📥 Download (.json)")

st.header("")

df = (
    DataFrame(keywords, columns=["Keyword/Keyphrase", "Relevancy"])
    .sort_values(by="Relevancy", ascending=False)
    .reset_index(drop=True)
)

df.index += 1

# Add styling
cmGreen = sns.light_palette("green", as_cmap=True)
cmRed = sns.light_palette("red", as_cmap=True)
df = df.style.background_gradient(
    cmap=cmGreen,
    subset=[
        "Relevancy",
    ],
)

c1, c2, c3 = st.columns([1, 3, 1])

format_dictionary = {
    "Relevancy": "{:.1%}",
}

df = df.format(format_dictionary)

with c2:
    st.table(df)
docker-compose.yml
ADDED
@@ -0,0 +1,10 @@
version: "3.8"


services:
  kpe:
    build: .
    ports:
      - "7201:7201"
    volumes:
      - "/home/dev/ml_models/kpe:/app/trained_model"
functionforDownloadButtons.py
ADDED
@@ -0,0 +1,171 @@
import streamlit as st
import pickle
import pandas as pd
import json
import base64
import uuid
import re
import math

import importlib.util


def import_from_file(module_name: str, filepath: str):
    """
    Imports a module from file.

    Args:
        module_name (str): Assigned to the module's __name__ parameter (does not
            influence how the module is named outside of this function)
        filepath (str): Path to the .py file

    Returns:
        The module
    """
    spec = importlib.util.spec_from_file_location(module_name, filepath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def notebook_header(text):
    """
    Insert section header into a jinja file, formatted as notebook cell.

    Leave 2 blank lines before the header.
    """
    return f"""# # {text}

"""


def code_header(text):
    """
    Insert section header into a jinja file, formatted as Python comment.

    Leave 2 blank lines before the header.
    """
    seperator_len = (75 - len(text)) / 2
    seperator_len_left = math.floor(seperator_len)
    seperator_len_right = math.ceil(seperator_len)
    return f"# {'-' * seperator_len_left} {text} {'-' * seperator_len_right}"


def to_notebook(code):
    """Converts Python code to Jupyter notebook format."""
    import jupytext  # optional dependency, only needed by this helper

    notebook = jupytext.reads(code, fmt="py")
    return jupytext.writes(notebook, fmt="ipynb")


def open_link(url, new_tab=True):
    """Dirty hack to open a new web page with a streamlit button."""
    # From: https://discuss.streamlit.io/t/how-to-link-a-button-to-a-webpage/1661/3
    from bokeh.models import Div  # optional dependency, only needed by this helper

    if new_tab:
        js = f"window.open('{url}')"  # New tab or window
    else:
        js = f"window.location.href = '{url}'"  # Current tab
    html = '<img src onerror="{}">'.format(js)
    div = Div(text=html)
    st.bokeh_chart(div)


def download_button(object_to_download, download_filename, button_text):
    """
    Generates a link to download the given object_to_download.

    From: https://discuss.streamlit.io/t/a-download-button-with-custom-css/4220

    Params:
    ------
    object_to_download: The object to be downloaded.
    download_filename (str): filename and extension of file, e.g. mydata.csv,
        some_txt_output.txt
    button_text (str): Text to display on the download button (e.g. 'click here to download file')

    Returns:
    -------
    None. Renders the styled download link with st.markdown.

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    # if pickle_it:
    #     try:
    #         object_to_download = pickle.dumps(object_to_download)
    #     except pickle.PicklingError as e:
    #         st.write(e)
    #         return None

    if isinstance(object_to_download, bytes):
        pass
    elif isinstance(object_to_download, pd.DataFrame):
        object_to_download = object_to_download.to_csv(index=False)
    # Try JSON encode for everything else
    else:
        object_to_download = json.dumps(object_to_download)

    try:
        # some strings <-> bytes conversions necessary here
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError:
        b64 = base64.b64encode(object_to_download).decode()

    button_uuid = str(uuid.uuid4()).replace("-", "")
    button_id = re.sub(r"\d+", "", button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                display: inline-flex;
                align-items: center;
                justify-content: center;
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                padding: .25rem .75rem;
                position: relative;
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    dl_link = (
        custom_css
        + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br><br>'
    )
    # dl_link = f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}"><input type="button" kind="primary" value="{button_text}"></a><br></br>'

    st.markdown(dl_link, unsafe_allow_html=True)


# def download_link(
#     content, label="Download", filename="file.txt", mimetype="text/plain"
# ):
#     """Create a HTML link to download a string as a file."""
#     # From: https://discuss.streamlit.io/t/how-to-download-file-in-streamlit/1806/9
#     b64 = base64.b64encode(
#         content.encode()
#     ).decode()  # some strings <-> bytes conversions necessary here
#     href = (
#         f'<a href="data:{mimetype};base64,{b64}" download="{filename}">{label}</a>'
#     )
#     return href
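As a quick illustration of the helper that app.py imports from this module, a hypothetical call with a small made-up DataFrame (the sample data and filenames are not part of the committed code):

import pandas as pd
from functionforDownloadButtons import download_button

# illustrative data only
sample_df = pd.DataFrame(
    {"Keyword/Keyphrase": ["هوش مصنوعی", "یادگیری ماشین"], "Relevancy": [0.91, 0.87]}
)
# renders a styled download link when run inside a Streamlit app
download_button(sample_df, "keywords.csv", "Download keywords (.csv)")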
kpe.py
ADDED
@@ -0,0 +1,67 @@
from flair.data import Sentence
from flair.models import SequenceTagger
from NERDA.models import NERDA
from hazm import word_tokenize
import flair
import utils


class KPE:
    def __init__(self, trained_kpe_model, flair_ner_model, device='cpu') -> None:
        self.extractor_model = NERDA(
            tag_scheme=['B-KEYWORD', 'I-KEYWORD'],
            tag_outside='O',
            transformer='xlm-roberta-large',
            max_len=512,
            device=device)
        flair.device = device

        self.extractor_model.load_network_from_file(trained_kpe_model)
        self.ner_tagger = SequenceTagger.load(flair_ner_model)
        self.IGNORE_TAGS = {'ORDINAL', 'DATE', 'CARDINAL'}

    @staticmethod
    def combine_keywords_nes(init_keywords, nes):
        # init_keywords = list(set(init_keywords))
        nes = list(set(nes))
        print('nes before combined ', nes)
        combined_keywords = []
        for kw in init_keywords:
            matched_index = utils.fuzzy_subword_match(kw, nes)
            if matched_index != -1:
                print(kw, nes[matched_index])
                combined_keywords.append(nes[matched_index])
                del nes[matched_index]
            else:
                combined_keywords.append(kw)
        print('nes after combined ', nes)
        combined_keywords.extend([n for n in nes if n not in combined_keywords])
        return combined_keywords

    def extract(self, txt, using_ner=True):
        sentence = Sentence(txt)

        # predict NER tags
        if using_ner:
            self.ner_tagger.predict(sentence)
            nes = [entity.text for entity in sentence.get_spans('ner') if entity.tag not in self.IGNORE_TAGS]
        else:
            nes = []

        # remove puncs
        nes = list(map(utils.remove_puncs, nes))
        print('nes ', nes)
        sentences, tags_conf = self.extractor_model.predict_text(txt, sent_tokenize=lambda txt: [txt], word_tokenize=lambda txt: txt.split(), return_confidence=True)
        init_keywords = utils.get_ne_from_iob_output(sentences, tags_conf)
        init_keywords = list(map(utils.remove_puncs, init_keywords))
        print('init keywords : ', init_keywords)

        # combine ner response and init keywords
        merged_keywords = self.combine_keywords_nes(init_keywords, nes)

        # set but keep order
        final_keywords = []
        for kw in merged_keywords:
            if kw not in final_keywords:
                final_keywords.append(kw)
        return final_keywords
kpe_ranker.py
ADDED
@@ -0,0 +1,24 @@
from kpe import KPE
import utils
import os
from sentence_transformers import SentenceTransformer
import ranker


class KpeRanker:
    def __init__(self):
        TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
        self.kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
        self.ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')

    def extract(self, text, count, using_ner, return_sorted):
        text = utils.normalize(text)
        kps = self.kpe.extract(text, using_ner=using_ner)
        if return_sorted:
            kps = ranker.get_sorted_keywords(self.ranker_transformer, text, kps)
        else:
            kps = [(kp, 1) for kp in kps]
        if len(kps) > count:
            kps = kps[:count]
        return kps
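For orientation, this is how the Streamlit front end (app.py) drives the class above; a minimal sketch assuming the NERDA checkpoint is present under trained_model/ and using a placeholder input text:

from kpe_ranker import KpeRanker

# loads the NERDA extractor, the flair NER tagger and the sentence-transformers ranker
extractor = KpeRanker()
keyphrases = extractor.extract(text="...", count=10, using_ner=True, return_sorted=True)
print(keyphrases)  # list of (keyphrase, score) tuples, highest-scoring first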
labeling.py
ADDED
@@ -0,0 +1,125 @@
import jsonlines
import json
from tqdm import tqdm
import time
from openai import error as openai_error
import pandas as pd
import openai
import tiktoken
import os
import glob

GPT_MODEL = 'gpt-3.5-turbo'
GPT_TOKEN_LIMIT = 1500
os.environ["OPENAI_API_KEY"] = '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'
# os.environ["OPENAI_API_KEY"] = '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'
openai.api_key = os.environ["OPENAI_API_KEY"]

LAST_INDEX_FILE_ADDR = 'last_index.txt'
TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'

# consecutive-error counter used by get_keyphrase_by_gpt
error_count = 0


def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))


def extract_seen_ids():
    seen_ids = set()
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        seen_ids.update([json.loads(line)['id'] for line in open(tagged_data_addr)])
    return seen_ids


def get_keyphrase_by_gpt(document) -> str:
    global error_count
    # prompt = 'extract main keywords from below document as sorted list (sort by importance). you should not use numbers for counting them. you should generate less than 10 keywords.'
    # prompt = 'Output only valid JSON list. Please extract the main keywords from the following document. The keywords should be in a comma-separated list, sorted by their importance. Do not use numbers to count the keywords. Try to generate less than 10 keywords.'
    prompt = 'there is a popular NLP task named KPE (keyphrase Extraction). please extract keyphrases from below article as a perfect Persian KPE model. '
    role_prompt = 'return your answer using json list format'
    message = prompt + '\n' + document
    # message = document
    messages = [
        # {"role": "system", "content": "Output only valid JSON list"},
        {"role": "system", "content": role_prompt},
        {"role": "user", "content": message},
    ]
    try:
        response = openai.ChatCompletion.create(
            model=GPT_MODEL,
            messages=messages,
            temperature=0
        )
        response_message = response["choices"][0]["message"]["content"]
        error_count = 0
        return response_message
    except Exception as e:
        if error_count > 3:
            raise e
        error_count += 1
        time.sleep(20)
        return []


# input_data = [json.loads(line) for line in open('all_data.json').read().splitlines()]
# input_data = open('all_data.json')
input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
# print('len input data : ', len(input_data))
try:
    last_index = int(open(LAST_INDEX_FILE_ADDR).read())
    print('load last index: ', last_index)
except:
    print('error in loading last index')
    last_index = 0


try:
    token_count = int(open(TOKEN_COUNT_FILE_ADDR).read())
    print('load token count: ', token_count)
except:
    print('error in loading token_count')
    token_count = 0

json_f_writer = jsonlines.open(f'tagged_data.jsonl_{str(last_index)}', mode='w')
seen_ids = extract_seen_ids()
for _, row_tup in enumerate(tqdm(input_data.iterrows(), total=len(input_data))):
    index, row = row_tup
    text = row['truncated_text_300']
    id = row['id']

    # filter by last index
    if index < last_index:
        print('skipping index: ', index)
        continue

    # filter by seen ids
    if id in seen_ids:
        print('repeated id, skip')
        continue

    # filter by gpt max token
    text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
    if text_gpt_token_count > GPT_TOKEN_LIMIT:
        continue

    token_count += text_gpt_token_count
    keyphrases = get_keyphrase_by_gpt(text)
    try:
        keyphrases = json.loads(keyphrases)
        if type(keyphrases) != list:
            # if type(keyphrases) == str:
            #     keyphrases = keyphrases.split(',')
            # else:
            print(str(index), ': not a list!')
    except:
        print(str(index), ': invalid json!')

    new_train_item = {'id': id, 'keyphrases': keyphrases}
    json_f_writer.write(new_train_item)
    last_index_f = open(LAST_INDEX_FILE_ADDR, 'w+')
    last_index_f.write(str(index))
    token_count_f = open(TOKEN_COUNT_FILE_ADDR, 'w+')
    token_count_f.write(str(token_count))

print(token_count)
logo.png
ADDED
main.py
ADDED
@@ -0,0 +1,71 @@
import uvicorn
import os
from typing import Union
from fastapi import FastAPI
from kpe import KPE
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi import APIRouter, Query
from sentence_transformers import SentenceTransformer
import utils
from ranker import get_sorted_keywords
from pydantic import BaseModel


app = FastAPI(
    title="AHD Persian KPE",
    # version=config.settings.VERSION,
    description="Keyphrase Extraction",
    openapi_url="/openapi.json",
    docs_url="/",
)

TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
# Sets all CORS enabled origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # str(origin) for origin in config.settings.BACKEND_CORS_ORIGINS
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class KpeParams(BaseModel):
    text: str
    count: int = 10000
    using_ner: bool = True
    return_sorted: bool = False


router = APIRouter()


@router.get("/")
def home():
    return "Welcome to AHD Keyphrase Extraction Service"


@router.post("/extract", description="extract keyphrases from Persian documents")
async def extract(kpe_params: KpeParams):
    global kpe
    text = utils.normalize(kpe_params.text)
    kps = kpe.extract(text, using_ner=kpe_params.using_ner)
    if kpe_params.return_sorted:
        kps = get_sorted_keywords(ranker_transformer, text, kps)
    else:
        kps = [(kp, 1) for kp in kps]
    if len(kps) > kpe_params.count:
        kps = kps[:kpe_params.count]
    return kps


app.include_router(router)


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7201)
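Once this FastAPI service is running on port 7201, a client could call the /extract endpoint as sketched below; the sketch assumes the third-party requests package and uses placeholder Persian text, neither of which is part of the repository:

import requests

payload = {
    "text": "متن فارسی نمونه برای استخراج عبارات کلیدی",  # placeholder input
    "count": 10,
    "using_ner": True,
    "return_sorted": True,
}
resp = requests.post("http://localhost:7201/extract", json=payload)
print(resp.json())  # list of [keyphrase, score] pairs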
ner_data_construction.py
ADDED
@@ -0,0 +1,70 @@
import pandas as pd
import json
import glob


def tag_document(keywords, tokens):

    # Initialize the tags list with all O's
    tags = ['O'] * len(tokens)

    # Loop over the keywords and tag the document
    for keyword in keywords:
        # Split the keyword into words
        keyword_words = keyword.split()

        # Loop over the words in the document
        for i in range(len(tokens)):
            # If the current word matches the first word of the keyword
            if tokens[i] == keyword_words[0]:
                match = True
                # Check if the rest of the words in the keyword match the following words in the document
                for j in range(1, len(keyword_words)):
                    if i + j >= len(tokens) or tokens[i + j] != keyword_words[j]:
                        match = False
                        break
                # If all the words in the keyword match the following words in the document, tag them as B-KEYWORD and I-KEYWORD
                if match:
                    tags[i] = 'B-KEYWORD'
                    for j in range(1, len(keyword_words)):
                        tags[i + j] = 'I-KEYWORD'

    return tags


def create_tner_dataset(all_tags, all_tokens, output_file_addr):
    output_f = open(output_file_addr, 'a+')
    for tags, tokens in zip(all_tags, all_tokens):
        for tag, tok in zip(tags, tokens):
            line = '\t'.join([tok, tag])
            output_f.write(line)
            output_f.write('\n')
        output_f.write('\n')


if __name__ == '__main__':

    data_df = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
    id2document = data_df.set_index('id')['truncated_text_300'].to_dict()

    # tag documents!
    print('------------------ tag documents --------------------')
    all_tags = []
    all_tokens = []
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        for line in open(tagged_data_addr):
            item = json.loads(line)
            if type(item['keyphrases']) == list:
                keywords = item['keyphrases']
                document = id2document[item['id']]
                tokens = document.split()
                tags = tag_document(keywords, tokens)
                assert len(tokens) == len(tags)
                all_tags.append(tags)
                all_tokens.append(tokens)
                print(len(keywords), len(tags), len(document.split()), len([t for t in tags if t[0] == 'B']))
    nerda_dataset = {'sentences': all_tokens, 'tags': all_tags}
    with open('nerda_dataset.json', 'w+') as f:
        json.dump(nerda_dataset, f)
    # create_tner_dataset(all_tags, all_tokens, output_file_addr='./sample_train.conll')
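To make the IOB labeling concrete, a small hypothetical call to tag_document (the tokens and keyword are invented for the example):

from ner_data_construction import tag_document

tokens = "هوش مصنوعی آینده فناوری است".split()
print(tag_document(["هوش مصنوعی"], tokens))
# -> ['B-KEYWORD', 'I-KEYWORD', 'O', 'O', 'O']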
predict.py
ADDED
@@ -0,0 +1,14 @@
import time
from kpe import KPE
import sys
import os


if __name__ == '__main__':
    TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model_10000.pt')
    text_addr = sys.argv[1]
    text = open(text_addr).read()
    kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
    s = time.time()
    print(kpe.extract(text))
    print(time.time() - s)
ranker.py
ADDED
@@ -0,0 +1,28 @@
"""
Semantic ranking of candidate keyphrases with sentence embeddings.

Each candidate keyphrase and the full text are embedded with a
sentence-transformers model; the candidates are then sorted by cosine
similarity to the text embedding.
"""
from sentence_transformers import util
import torch


def get_sorted_keywords(embedder, text, keywords):
    top_k = len(keywords)
    keywords_embedding = embedder.encode(keywords, convert_to_tensor=True)
    text_embedding = embedder.encode(text, convert_to_tensor=True)

    cos_scores = util.cos_sim(keywords_embedding, text_embedding).squeeze(dim=1)
    # print(cos_scores.size())
    top_results = torch.topk(cos_scores, k=top_k)
    return [(keywords[idx], top_results[0][index].item()) for index, idx in enumerate(top_results[1])]
    # return [keywords[idx] for idx in top_results[1]]
    # for score, idx in zip(top_results[0], top_results[1]):
    #     print(keywords[idx], "(Score: {:.4f})".format(score))
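A minimal sketch of using the ranker directly, with the same model name that kpe_ranker.py loads; the text and candidate list are invented for illustration:

from sentence_transformers import SentenceTransformer
from ranker import get_sorted_keywords

embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
text = "متن نمونه درباره هوش مصنوعی و یادگیری ماشین"      # placeholder text
candidates = ["هوش مصنوعی", "یادگیری ماشین", "متن"]       # placeholder candidates
print(get_sorted_keywords(embedder, text, candidates))
# -> [(keyphrase, cosine score), ...] sorted from most to least similar to the text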
requirements.txt
ADDED
@@ -0,0 +1,9 @@
fastapi
uvicorn
flair
hazm
parsinorm
pydantic
seaborn
streamlit
altair==4.2.2
utils.py
ADDED
@@ -0,0 +1,63 @@
from parsinorm import General_normalization
import re


def get_ne_from_iob_output(sentences, tags_conf):
    sentences = sentences[0]
    tags = tags_conf[0][0]
    confs = tags_conf[1][0]

    seen_b = False
    keywords = {}
    new_token = []
    begin_index = 0
    for index, (tok, tag) in enumerate(zip(sentences, tags)):
        if tag[0] == 'I' and seen_b:
            new_token.append(tok)
        if tag[0] == 'B':
            if new_token:
                keywords[' '.join(new_token)] = confs[begin_index]
                new_token = []
            new_token.append(tok)
            begin_index = index
            seen_b = True
        if tag[0] == 'O':
            if new_token:
                keywords[' '.join(new_token)] = confs[begin_index]
                new_token = []
            seen_b = False

    # print('keywords before sort: ', [k for k in keywords.keys()])
    # sort by confidence of the B- token, highest first
    sorted_keywords = sorted(list(keywords.keys()), key=lambda kw: keywords[kw], reverse=True)
    print('keywords after sort: ', sorted_keywords)
    return sorted_keywords


def fuzzy_subword_match(key, words):
    for index, w in enumerate(words):
        if (len(key.split()) < len(w.split())) and key in w:
            return index
    return -1


# normalize
def normalize(txt):
    general_normalization = General_normalization()
    txt = general_normalization.alphabet_correction(txt)
    txt = general_normalization.semi_space_correction(txt)
    txt = general_normalization.english_correction(txt)
    txt = general_normalization.html_correction(txt)
    txt = general_normalization.arabic_correction(txt)
    txt = general_normalization.punctuation_correction(txt)
    txt = general_normalization.specials_chars(txt)
    txt = general_normalization.remove_emojis(txt)
    txt = general_normalization.number_correction(txt)
    txt = general_normalization.remove_not_desired_chars(txt)
    txt = general_normalization.remove_repeated_punctuation(txt)
    return ' '.join(txt.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').split())


def remove_puncs(txt):
    return re.sub(r'[!?،\(\)\.]', '', txt)
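A couple of hypothetical calls illustrating the helpers above (inputs invented for the example, expected outputs shown as comments):

from utils import fuzzy_subword_match, remove_puncs

# fuzzy_subword_match returns the index of the first longer phrase that contains the key
print(fuzzy_subword_match('تهران', ['دانشگاه تهران', 'ایران']))   # -> 0
print(fuzzy_subword_match('کتاب', ['ایران', 'فرانسه']))           # -> -1

# remove_puncs strips the listed punctuation marks
print(remove_puncs('هوش مصنوعی، (یادگیری ماشین)!'))               # -> 'هوش مصنوعی یادگیری ماشین'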