mrmft committed
Commit 4da642e
1 parent: 39d7a1a

adding project source

Files changed (14)
  1. Dockerfile +26 -0
  2. app.py +202 -0
  3. docker-compose.yml +10 -0
  4. functionforDownloadButtons.py +171 -0
  5. kpe.py +67 -0
  6. kpe_ranker.py +24 -0
  7. labeling.py +125 -0
  8. logo.png +0 -0
  9. main.py +71 -0
  10. ner_data_construction.py +70 -0
  11. predict.py +14 -0
  12. ranker.py +28 -0
  13. requirements.txt +9 -0
  14. utils.py +63 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM python:3.9
+
+ RUN mkdir /app
+ WORKDIR /app
+
+
+ # download model and put in trained_model folder
+ # RUN wget https://drive.ahdsoft.dev/s/xp5Mb7bQ34Z7BRX/download/trained_model_10000.pt
+ # RUN mkdir trained_model
+ # RUN mv trained_model_10000.pt trained_model/
+
+ # download packages
+ COPY requirements.txt .
+
+ ENV HTTP_PROXY http://172.17.0.1:10805
+ ENV HTTPS_PROXY http://172.17.0.1:10805
+ ENV http_proxy http://172.17.0.1:10805
+ ENV https_proxy http://172.17.0.1:10805
+
+ RUN pip install git+https://github.com/mohammadkarrabi/NERDA.git
+ RUN pip install -r requirements.txt
+ RUN pip install sentence_transformers
+
+ COPY . .
+
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=7201", "--server.address=0.0.0.0", "--client.showErrorDetails=false"]
app.py ADDED
@@ -0,0 +1,202 @@
+ import streamlit as st
+ import numpy as np
+ import re
+ from pandas import DataFrame
+ # from keybert import KeyBERT
+ # For Flair (Keybert)
+ # from flair.embeddings import TransformerDocumentEmbeddings
+ import seaborn as sns
+ # For download buttons
+ from functionforDownloadButtons import download_button
+ import os
+ import json
+
+ from kpe_ranker import KpeRanker
+
+ st.set_page_config(
+     page_title="استخراج عبارات کلیدی عهد",  # "Ahd keyphrase extraction"
+     page_icon="🎈",
+ )
+
+
+ def _max_width_():
+     max_width_str = "max-width: 1400px;"
+     st.markdown(
+         f"""
+         <style>
+         .reportview-container .main .block-container{{
+             {max_width_str}
+         }}
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+
+ _max_width_()
+
+ c30, c31, c32 = st.columns([2.5, 1, 3])
+
+ with c30:
+     # st.image("logo.png", width=400)
+     st.title("🔑 استخراج عبارات کلیدی")  # "Keyphrase extraction"
+     st.header("")
+
+
+ with st.expander("ℹ️ - About this app", expanded=True):
+
+     # (Persian) "Keyphrase extraction is a new product from the Ahd company that, in the
+     # evaluations performed, has shown higher accuracy than its competitors."
+     st.write(
+         """
+         - استخراج عبارات کلیدی، محصولی نوین از شرکت عهد است که در ارزیابی‌های صورت‌گرفته، دقت بیشتری را نسبت به رقبا از خود نشان داده است.
+         """
+     )
+
+     st.markdown("")
+
+ st.markdown("")
+ # st.markdown("## **...**")
+ with st.form(key="my_form"):
+
+     ce, c1, ce, c2, c3 = st.columns([0.07, 1, 0.07, 5, 0.07])
+     with c1:
+
+         # if ModelType == "Default (DistilBERT)":
+         #     kw_model = KeyBERT(model=roberta)
+
+         @st.cache(allow_output_mutation=True)
+         def load_model():
+             return KpeRanker()
+
+         kpe_ranker_extractor = load_model()
+
+         # else:
+         #     @st.cache(allow_output_mutation=True)
+         #     def load_model():
+         #         return KeyBERT("distilbert-base-nli-mean-tokens")
+         #
+         #     kw_model = load_model()
+
+         top_N = st.slider(
+             "# تعداد",  # "count"
+             min_value=1,
+             max_value=30,
+             value=10,
+             help="You can choose the number of keywords/keyphrases to display. Between 1 and 30, the default is 10.",
+         )
+         # min_Ngrams = st.number_input(
+         #     "Minimum Ngram",
+         #     min_value=1,
+         #     max_value=4,
+         #     help="""The minimum value for the ngram range.
+         #
+         #     *Keyphrase_ngram_range* sets the length of the resulting keywords/keyphrases.
+         #
+         #     To extract keyphrases, simply set *keyphrase_ngram_range* to (1, 2) or higher depending on the number of words you would like in the resulting keyphrases.""",
+         #     # help="Minimum value for the keyphrase_ngram_range. keyphrase_ngram_range sets the length of the resulting keywords/keyphrases. To extract keyphrases, simply set keyphrase_ngram_range to (1, 2) or higher depending on the number of words you would like in the resulting keyphrases.",
+         # )
+
+         # max_Ngrams = st.number_input(
+         #     "Maximum Ngram",
+         #     value=2,
+         #     min_value=1,
+         #     max_value=4,
+         #     help="""The maximum value for the keyphrase_ngram_range.
+         #
+         #     *Keyphrase_ngram_range* sets the length of the resulting keywords/keyphrases.
+         #
+         #     To extract keyphrases, simply set *keyphrase_ngram_range* to (1, 2) or higher depending on the number of words you would like in the resulting keyphrases.""",
+         # )
+
+         # StopWordsCheckbox = st.checkbox(
+         #     "Remove stop words",
+         #     help="Tick this box to remove stop words from the document (currently English only)",
+         # )
+
+         use_ner = st.checkbox(
+             "NER",
+             value=True,
+             help="استفاده از شناسایی موجودیت‌های نام‌دار",  # "use named-entity recognition"
+         )
+
+     with c2:
+         doc = st.text_area(
+             "متن خود را وارد کنید",  # "enter your text"
+             height=510,
+         )
+
+         MAX_WORDS = 500
+         res = len(re.findall(r"\w+", doc))
+         if res > MAX_WORDS:
+             st.warning(
+                 "⚠️ Your text contains "
+                 + str(res)
+                 + " words."
+                 + " Only the first 500 words will be reviewed. Stay tuned as increased allowance is coming! 😊"
+             )
+             # keep only the first MAX_WORDS words (slicing the string would truncate by characters)
+             doc = " ".join(doc.split()[:MAX_WORDS])
+
+         submit_button = st.form_submit_button(label="✨ پردازش")  # "process"
+
+
+ if not submit_button:
+     st.stop()
+
+
+ #################################### get keyphrases #######################################################
+
+ keywords = kpe_ranker_extractor.extract(text=doc, count=top_N, using_ner=use_ner, return_sorted=True)
+ # print(keywords)
+ st.markdown("## **🎈 Check & download results**")
+
+ st.header("")
+
+ cs, c1, c2, c3, cLast = st.columns([2, 1.5, 1.5, 1.5, 2])
+
+ with c1:
+     CSVButton2 = download_button(keywords, "Data.csv", "📥 Download (.csv)")
+ with c2:
+     CSVButton2 = download_button(keywords, "Data.txt", "📥 Download (.txt)")
+ with c3:
+     CSVButton2 = download_button(keywords, "Data.json", "📥 Download (.json)")
+
+ st.header("")
+
+ df = (
+     DataFrame(keywords, columns=["Keyword/Keyphrase", "Relevancy"])
+     .sort_values(by="Relevancy", ascending=False)
+     .reset_index(drop=True)
+ )
+
+ df.index += 1
+
+ # Add styling
+ cmGreen = sns.light_palette("green", as_cmap=True)
+ cmRed = sns.light_palette("red", as_cmap=True)
+ df = df.style.background_gradient(
+     cmap=cmGreen,
+     subset=[
+         "Relevancy",
+     ],
+ )
+
+ c1, c2, c3 = st.columns([1, 3, 1])
+
+ format_dictionary = {
+     "Relevancy": "{:.1%}",
+ }
+
+ df = df.format(format_dictionary)
+
+ with c2:
+     st.table(df)
docker-compose.yml ADDED
@@ -0,0 +1,10 @@
+ version: "3.8"
+
+
+ services:
+   kpe:
+     build: .
+     ports:
+       - "7201:7201"
+     volumes:
+       - "/home/dev/ml_models/kpe:/app/trained_model"
functionforDownloadButtons.py ADDED
@@ -0,0 +1,171 @@
+ import streamlit as st
+ import pickle
+ import pandas as pd
+ import json
+ import base64
+ import uuid
+ import re
+ import math
+
+ import importlib.util
+
+
+ def import_from_file(module_name: str, filepath: str):
+     """
+     Imports a module from file.
+
+     Args:
+         module_name (str): Assigned to the module's __name__ parameter (does not
+             influence how the module is named outside of this function)
+         filepath (str): Path to the .py file
+
+     Returns:
+         The module
+     """
+     spec = importlib.util.spec_from_file_location(module_name, filepath)
+     module = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(module)
+     return module
+
+
+ def notebook_header(text):
+     """
+     Insert section header into a jinja file, formatted as notebook cell.
+
+     Leave 2 blank lines before the header.
+     """
+     return f"""# # {text}
+
+ """
+
+
+ def code_header(text):
+     """
+     Insert section header into a jinja file, formatted as Python comment.
+
+     Leave 2 blank lines before the header.
+     """
+     separator_len = (75 - len(text)) / 2
+     separator_len_left = math.floor(separator_len)
+     separator_len_right = math.ceil(separator_len)
+     return f"# {'-' * separator_len_left} {text} {'-' * separator_len_right}"
+
+
+ def to_notebook(code):
+     """Converts Python code to Jupyter notebook format."""
+     import jupytext  # optional dependency, only needed when this helper is used
+     notebook = jupytext.reads(code, fmt="py")
+     return jupytext.writes(notebook, fmt="ipynb")
+
+
+ def open_link(url, new_tab=True):
+     """Dirty hack to open a new web page with a streamlit button."""
+     # From: https://discuss.streamlit.io/t/how-to-link-a-button-to-a-webpage/1661/3
+     from bokeh.models import Div  # optional dependency, only needed when this helper is used
+     if new_tab:
+         js = f"window.open('{url}')"  # New tab or window
+     else:
+         js = f"window.location.href = '{url}'"  # Current tab
+     html = '<img src onerror="{}">'.format(js)
+     div = Div(text=html)
+     st.bokeh_chart(div)
+
+
+ def download_button(object_to_download, download_filename, button_text):
+     """
+     Generates a link to download the given object_to_download.
+
+     From: https://discuss.streamlit.io/t/a-download-button-with-custom-css/4220
+
+     Params:
+     ------
+     object_to_download: The object to be downloaded.
+     download_filename (str): filename and extension of file, e.g. mydata.csv,
+         some_txt_output.txt
+     button_text (str): Text to display on the download button (e.g. 'click here to download file')
+
+     Returns:
+     -------
+     (str): the anchor tag to download object_to_download
+
+     Examples:
+     --------
+     download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
+     download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
+
+     """
+     # if pickle_it:
+     #     try:
+     #         object_to_download = pickle.dumps(object_to_download)
+     #     except pickle.PicklingError as e:
+     #         st.write(e)
+     #         return None
+
+     # bytes pass through unchanged, DataFrames become CSV, everything else is JSON-encoded
+     if isinstance(object_to_download, bytes):
+         pass
+
+     elif isinstance(object_to_download, pd.DataFrame):
+         object_to_download = object_to_download.to_csv(index=False)
+     # Try JSON encode for everything else
+     else:
+         object_to_download = json.dumps(object_to_download)
+
+     try:
+         # some strings <-> bytes conversions necessary here
+         b64 = base64.b64encode(object_to_download.encode()).decode()
+     except AttributeError:
+         b64 = base64.b64encode(object_to_download).decode()
+
+     button_uuid = str(uuid.uuid4()).replace("-", "")
+     button_id = re.sub(r"\d+", "", button_uuid)
+
+     custom_css = f"""
+         <style>
+             #{button_id} {{
+                 display: inline-flex;
+                 align-items: center;
+                 justify-content: center;
+                 background-color: rgb(255, 255, 255);
+                 color: rgb(38, 39, 48);
+                 padding: .25rem .75rem;
+                 position: relative;
+                 text-decoration: none;
+                 border-radius: 4px;
+                 border-width: 1px;
+                 border-style: solid;
+                 border-color: rgb(230, 234, 241);
+                 border-image: initial;
+             }}
+             #{button_id}:hover {{
+                 border-color: rgb(246, 51, 102);
+                 color: rgb(246, 51, 102);
+             }}
+             #{button_id}:active {{
+                 box-shadow: none;
+                 background-color: rgb(246, 51, 102);
+                 color: white;
+             }}
+         </style> """
+
+     dl_link = (
+         custom_css
+         + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br><br>'
+     )
+     # dl_link = f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}"><input type="button" kind="primary" value="{button_text}"></a><br></br>'
+
+     st.markdown(dl_link, unsafe_allow_html=True)
+
+
+ # def download_link(
+ #     content, label="Download", filename="file.txt", mimetype="text/plain"
+ # ):
+ #     """Create a HTML link to download a string as a file."""
+ #     # From: https://discuss.streamlit.io/t/how-to-download-file-in-streamlit/1806/9
+ #     b64 = base64.b64encode(
+ #         content.encode()
+ #     ).decode()  # some strings <-> bytes conversions necessary here
+ #     href = (
+ #         f'<a href="data:{mimetype};base64,{b64}" download="{filename}">{label}</a>'
+ #     )
+ #     return href
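For reference, a minimal usage sketch of download_button outside the app (a hedged sketch: it assumes a running Streamlit session, and the keyword pairs below are made up):

import pandas as pd
from functionforDownloadButtons import download_button

# hypothetical extraction output: (keyphrase, relevancy) pairs
keywords = [("هوش مصنوعی", 0.91), ("پردازش زبان طبیعی", 0.84)]

download_button(pd.DataFrame(keywords), "Data.csv", "📥 Download (.csv)")   # DataFrame -> CSV branch
download_button(keywords, "Data.json", "📥 Download (.json)")               # anything else -> JSON branch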
kpe.py ADDED
@@ -0,0 +1,67 @@
+ from flair.data import Sentence
+ from flair.models import SequenceTagger
+ from NERDA.models import NERDA
+ from hazm import word_tokenize
+ import flair
+ import utils
+
+
+ class KPE:
+     def __init__(self, trained_kpe_model, flair_ner_model, device='cpu') -> None:
+         self.extractor_model = NERDA(
+             tag_scheme=['B-KEYWORD', 'I-KEYWORD'],
+             tag_outside='O',
+             transformer='xlm-roberta-large',
+             max_len=512,
+             device=device)
+         flair.device = device
+
+         self.extractor_model.load_network_from_file(trained_kpe_model)
+         self.ner_tagger = SequenceTagger.load(flair_ner_model)
+         self.IGNORE_TAGS = {'ORDINAL', 'DATE', 'CARDINAL'}
+
+     @staticmethod
+     def combine_keywords_nes(init_keywords, nes):
+         # init_keywords = list(set(init_keywords))
+         nes = list(set(nes))
+         print('nes before combined ', nes)
+         combined_keywords = []
+         for kw in init_keywords:
+             matched_index = utils.fuzzy_subword_match(kw, nes)
+             if matched_index != -1:
+                 print(kw, nes[matched_index])
+                 combined_keywords.append(nes[matched_index])
+                 del nes[matched_index]
+             else:
+                 combined_keywords.append(kw)
+         print('nes after combined ', nes)
+         combined_keywords.extend([n for n in nes if n not in combined_keywords])
+         return combined_keywords
+
+     def extract(self, txt, using_ner=True):
+         sentence = Sentence(txt)
+
+         # predict NER tags
+         if using_ner:
+             self.ner_tagger.predict(sentence)
+             nes = [entity.text for entity in sentence.get_spans('ner') if entity.tag not in self.IGNORE_TAGS]
+         else:
+             nes = []
+
+         # remove puncs
+         nes = list(map(utils.remove_puncs, nes))
+         print('nes ', nes)
+         sentences, tags_conf = self.extractor_model.predict_text(txt, sent_tokenize=lambda txt: [txt], word_tokenize=lambda txt: txt.split(), return_confidence=True)
+         init_keywords = utils.get_ne_from_iob_output(sentences, tags_conf)
+         init_keywords = list(map(utils.remove_puncs, init_keywords))
+         print('init keywords : ', init_keywords)
+
+         # combine ner response and init keywords
+         merged_keywords = self.combine_keywords_nes(init_keywords, nes)
+
+         # deduplicate but keep order
+         final_keywords = []
+         for kw in merged_keywords:
+             if kw not in final_keywords:
+                 final_keywords.append(kw)
+         return final_keywords
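A small illustration of how combine_keywords_nes merges the tagger's keyphrases with NER spans (a sketch, assuming the NERDA and flair dependencies are importable; the strings are made up, and the expected result follows the method's logic: an NER span that contains a keyphrase replaces it, and unmatched entities are appended):

from kpe import KPE

init_keywords = ["دانشگاه تهران", "هوش مصنوعی"]
nes = ["دانشگاه تهران ایران"]  # a longer NER span containing the first keyphrase

merged = KPE.combine_keywords_nes(init_keywords, nes)  # static method, no model loading required
print(merged)  # expected: ['دانشگاه تهران ایران', 'هوش مصنوعی']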
kpe_ranker.py ADDED
@@ -0,0 +1,24 @@
+ from kpe import KPE
+ import utils
+ import os
+ from sentence_transformers import SentenceTransformer
+ import ranker
+
+
+ class KpeRanker:
+     def __init__(self):
+         TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
+         self.kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
+         self.ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
+
+     def extract(self, text, count, using_ner, return_sorted):
+         text = utils.normalize(text)
+         kps = self.kpe.extract(text, using_ner=using_ner)
+         if return_sorted:
+             kps = ranker.get_sorted_keywords(self.ranker_transformer, text, kps)
+         else:
+             kps = [(kp, 1) for kp in kps]
+         if len(kps) > count:
+             kps = kps[:count]
+         return kps
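A minimal usage sketch of KpeRanker (assuming the trained checkpoint exists under trained_model/ and the flair and sentence-transformers models can be downloaded; the input text is a placeholder):

from kpe_ranker import KpeRanker

ranker = KpeRanker()  # loads the KPE tagger, the flair NER model and the ranking transformer
text = "..."          # any Persian document
# top-10 keyphrases with cosine-similarity scores, highest first
for phrase, score in ranker.extract(text, count=10, using_ner=True, return_sorted=True):
    print(f"{score:.3f}\t{phrase}")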
labeling.py ADDED
@@ -0,0 +1,125 @@
+ import jsonlines
+ import json
+ from tqdm import tqdm
+ import time
+ from openai import error as openai_error
+ import pandas as pd
+ import openai
+ import tiktoken
+ import os
+ import glob
+
+ GPT_MODEL = 'gpt-3.5-turbo'
+ GPT_TOKEN_LIMIT = 1500
+ os.environ["OPENAI_API_KEY"] = 'sk-...'  # hardcoded key redacted; set your own key here or via the environment
+ # os.environ["OPENAI_API_KEY"] = 'sk-...'
+ openai.api_key = os.environ["OPENAI_API_KEY"]
+
+ LAST_INDEX_FILE_ADDR = 'last_index.txt'
+ TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'
+
+ error_count = 0  # consecutive-error counter used by get_keyphrase_by_gpt
+
+
+ def num_tokens(text: str, model: str = GPT_MODEL) -> int:
+     """Return the number of tokens in a string."""
+     encoding = tiktoken.encoding_for_model(model)
+     return len(encoding.encode(text))
+
+
+ def extract_seen_ids():
+     seen_ids = set()
+     for tagged_data_addr in glob.iglob('./tagged_data*'):
+         seen_ids.update([json.loads(line)['id'] for line in open(tagged_data_addr)])
+     return seen_ids
+
+
+ def get_keyphrase_by_gpt(document) -> str:
+     global error_count
+     # prompt = 'extract main keywords from below document as sorted list (sort by importance). you should not use numbers for counting them. you should generate less than 10 keywords.'
+     # prompt = 'Output only valid JSON list. Please extract the main keywords from the following document. The keywords should be in a comma-separated list, sorted by their importance. Do not use numbers to count the keywords. Try to generate less than 10 keywords.'
+     prompt = 'there is a popular NLP task named KPE (keyphrase Extraction). please extract keyphrases from below article as a perfect Persian KPE model. '
+     role_prompt = 'return your answer using json list format'
+     message = prompt + '\n' + document
+     # message = prompt + '\n' + document
+     # message = document
+     messages = [
+         # {"role": "system", "content": "Output only valid JSON list"},
+         {"role": "system", "content": role_prompt},
+         {"role": "user", "content": message},
+     ]
+     try:
+         response = openai.ChatCompletion.create(
+             model=GPT_MODEL,
+             messages=messages,
+             temperature=0
+         )
+         response_message = response["choices"][0]["message"]["content"]
+         error_count = 0
+         return response_message
+     except Exception as e:
+         if error_count > 3:
+             raise e
+         error_count += 1
+         time.sleep(20)
+         return []
+
+
+ # input_data = [json.load(line) for line in open('all_data.json').read().splitlines())
+ # input_data = open('all_data.json')
+ input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
+ # print('len input data : ', len(input_data))
+ try:
+     last_index = int(open(LAST_INDEX_FILE_ADDR).read())
+     print('load last index: ', last_index)
+ except:
+     print('error in loading last index')
+     last_index = 0
+
+
+ try:
+     token_count = int(open(TOKEN_COUNT_FILE_ADDR).read())
+     print('load token count: ', token_count)
+ except:
+     print('error in loading token_count')
+     token_count = 0
+
+ json_f_writer = jsonlines.open(f'tagged_data.jsonl_{str(last_index)}', mode='w')
+ seen_ids = extract_seen_ids()
+ for _, row_tup in enumerate(tqdm(input_data.iterrows(), total=len(input_data))):
+     index, row = row_tup
+     text = row['truncated_text_300']
+     id = row['id']
+
+     # filter by last index
+     if index < last_index:
+         print('skipping index: ', index)
+         continue
+
+     # filter by seen ids
+     if id in seen_ids:
+         print('repeated id, skipping')
+         continue
+
+     # filter by gpt max token
+     text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
+     if text_gpt_token_count > GPT_TOKEN_LIMIT:
+         continue
+
+     token_count += text_gpt_token_count
+     keyphrases = get_keyphrase_by_gpt(text)
+     try:
+         keyphrases = json.loads(keyphrases)
+         if type(keyphrases) != list:
+             # if type(keyphrases) == str:
+             #     keyphrases = keyphrases.split(',')
+             # else:
+             print(str(index), ': not a list!')
+     except:
+         print(str(index), ': invalid json!')
+
+     new_train_item = {'id': id, 'keyphrases': keyphrases}
+     json_f_writer.write(new_train_item)
+     last_index_f = open(LAST_INDEX_FILE_ADDR, 'w+')
+     last_index_f.write(str(index))
+     token_count_f = open(TOKEN_COUNT_FILE_ADDR, 'w+')
+     token_count_f.write(str(token_count))
+
+ print(token_count)
logo.png ADDED
main.py ADDED
@@ -0,0 +1,71 @@
+ import uvicorn
+ import os
+ from typing import Union
+ from fastapi import FastAPI
+ from kpe import KPE
+ from fastapi.middleware.cors import CORSMiddleware
+ # from fastapi.middleware.trustedhost import TrustedHostMiddleware
+ from fastapi import APIRouter, Query
+ from sentence_transformers import SentenceTransformer
+ import utils
+ from ranker import get_sorted_keywords
+ from pydantic import BaseModel
+
+
+ app = FastAPI(
+     title="AHD Persian KPE",
+     # version=config.settings.VERSION,
+     description="Keyphrase Extraction",
+     openapi_url="/openapi.json",
+     docs_url="/",
+ )
+
+ TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
+ kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
+ ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
+ # Sets all CORS enabled origins
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # str(origin) for origin in config.settings.BACKEND_CORS_ORIGINS
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ class KpeParams(BaseModel):
+     text: str
+     count: int = 10000
+     using_ner: bool = True
+     return_sorted: bool = False
+
+
+ router = APIRouter()
+
+
+ @router.get("/")
+ def home():
+     return "Welcome to AHD Keyphrase Extraction Service"
+
+
+ @router.post("/extract", description="extract keyphrases from Persian documents")
+ async def extract(kpe_params: KpeParams):
+     global kpe
+     text = utils.normalize(kpe_params.text)
+     kps = kpe.extract(text, using_ner=kpe_params.using_ner)
+     if kpe_params.return_sorted:
+         kps = get_sorted_keywords(ranker_transformer, text, kps)
+     else:
+         kps = [(kp, 1) for kp in kps]
+     if len(kps) > kpe_params.count:
+         kps = kps[:kpe_params.count]
+     return kps
+
+
+ app.include_router(router)
+
+
+ if __name__ == "__main__":
+     uvicorn.run("main:app", host="0.0.0.0", port=7201)
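A hedged client-side sketch for the /extract endpoint above, assuming the service is running locally on port 7201; the requests package is an extra assumption (it is not listed in requirements.txt), and the sample text is made up:

import requests

payload = {
    "text": "متن فارسی برای استخراج عبارات کلیدی",
    "count": 10,
    "using_ner": True,
    "return_sorted": True,
}
resp = requests.post("http://localhost:7201/extract", json=payload, timeout=300)
resp.raise_for_status()
for phrase, score in resp.json():  # a list of [keyphrase, score] pairs
    print(score, phrase)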
ner_data_construction.py ADDED
@@ -0,0 +1,70 @@
+ import pandas as pd
+ import json
+ import glob
+
+
+ def tag_document(keywords, tokens):
+
+     # Initialize the tags list with all O's
+     tags = ['O'] * len(tokens)
+
+     # Loop over the keywords and tag the document
+     for keyword in keywords:
+         # Split the keyword into words
+         keyword_words = keyword.split()
+
+         # Loop over the words in the document
+         for i in range(len(tokens)):
+             # If the current word matches the first word of the keyword
+             if tokens[i] == keyword_words[0]:
+                 match = True
+                 # Check if the rest of the words in the keyword match the following words in the document
+                 for j in range(1, len(keyword_words)):
+                     if i + j >= len(tokens) or tokens[i + j] != keyword_words[j]:
+                         match = False
+                         break
+                 # If all the words in the keyword match the following words in the document, tag them as B-KEYWORD and I-KEYWORD
+                 if match:
+                     tags[i] = 'B-KEYWORD'
+                     for j in range(1, len(keyword_words)):
+                         tags[i + j] = 'I-KEYWORD'
+
+     return tags
+
+
+ def create_tner_dataset(all_tags, all_tokens, output_file_addr):
+     output_f = open(output_file_addr, 'a+')
+     for tags, tokens in zip(all_tags, all_tokens):
+         for tag, tok in zip(tags, tokens):
+             line = '\t'.join([tok, tag])
+             output_f.write(line)
+             output_f.write('\n')
+         output_f.write('\n')
+
+
+ if __name__ == '__main__':
+
+     data_df = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
+     id2document = data_df.set_index('id')['truncated_text_300'].to_dict()
+
+     # tag documents!
+     print('------------------ tag documents --------------------')
+     all_tags = []
+     all_tokens = []
+     for tagged_data_addr in glob.iglob('./tagged_data*'):
+         for line in open(tagged_data_addr):
+             item = json.loads(line)
+             if type(item['keyphrases']) == list:
+                 keywords = item['keyphrases']
+                 document = id2document[item['id']]
+                 tokens = document.split()
+                 tags = tag_document(keywords, tokens)
+                 assert len(tokens) == len(tags)
+                 all_tags.append(tags)
+                 all_tokens.append(tokens)
+                 print(len(keywords), len(tags), len(document.split()), len([t for t in tags if t[0] == 'B']))
+     nerda_dataset = {'sentences': all_tokens, 'tags': all_tags}
+     with open('nerda_dataset.json', 'w+') as f:
+         json.dump(nerda_dataset, f)
+     # create_tner_dataset(all_tags, all_tokens, output_file_addr='./sample_train.conll')
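A toy run of tag_document showing the IOB tags it produces (the tokens and keyphrases are made up; the expected output follows directly from the matching loop above):

from ner_data_construction import tag_document

tokens = "هوش مصنوعی در ایران رشد می‌کند".split()
keywords = ["هوش مصنوعی", "ایران"]
print(tag_document(keywords, tokens))
# expected: ['B-KEYWORD', 'I-KEYWORD', 'O', 'B-KEYWORD', 'O', 'O']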
predict.py ADDED
@@ -0,0 +1,14 @@
+ import time
+ from kpe import KPE
+ import sys
+ import os
+
+
+ if __name__ == '__main__':
+     TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model_10000.pt')
+     text_addr = sys.argv[1]
+     text = open(text_addr).read()
+     kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
+     s = time.time()
+     print(kpe.extract(text))
+     print(time.time() - s)
ranker.py ADDED
@@ -0,0 +1,28 @@
+ """
+ Keyphrase ranking with sentence embeddings (semantic search).
+
+ The document and every candidate keyphrase are embedded with a
+ sentence-transformer, and the keyphrases are sorted by their cosine
+ similarity to the document embedding.
+ (Adapted from the sentence-transformers semantic-search example.)
+ """
+ from sentence_transformers import util
+ import torch
+
+
+ def get_sorted_keywords(embedder, text, keywords):
+     top_k = len(keywords)
+     keywords_embedding = embedder.encode(keywords, convert_to_tensor=True)
+     text_embedding = embedder.encode(text, convert_to_tensor=True)
+
+     cos_scores = util.cos_sim(keywords_embedding, text_embedding).squeeze(dim=1)
+     # print(cos_scores.size())
+     top_results = torch.topk(cos_scores, k=top_k)
+     return [(keywords[idx], top_results[0][index].item()) for index, idx in enumerate(top_results[1])]
+     # return [keywords[idx] for idx in top_results[1]]
+     # for score, idx in zip(top_results[0], top_results[1]):
+     #     print(keywords[idx], "(Score: {:.4f})".format(score))
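A small sketch of get_sorted_keywords with the same sentence-transformer used elsewhere in this repo (model download assumed to succeed; the sample strings are made up):

from sentence_transformers import SentenceTransformer
from ranker import get_sorted_keywords

embedder = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2", device="cpu")
text = "متن نمونه درباره پردازش زبان طبیعی و استخراج عبارات کلیدی"
keywords = ["پردازش زبان طبیعی", "استخراج عبارات کلیدی", "ورزش"]
# returns (keyphrase, cosine similarity to the document) pairs, highest first
print(get_sorted_keywords(embedder, text, keywords))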
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi
+ uvicorn
+ flair
+ hazm
+ parsinorm
+ pydantic
+ seaborn
+ streamlit
+ altair==4.2.2
utils.py ADDED
@@ -0,0 +1,63 @@
+ from parsinorm import General_normalization
+ import re
+
+
+ def get_ne_from_iob_output(sentences, tags_conf):
+     sentences = sentences[0]
+     tags = tags_conf[0][0]
+     confs = tags_conf[1][0]
+
+     seen_b = False
+     keywords = {}
+     new_token = []
+     begin_index = 0
+     for index, (tok, tag) in enumerate(zip(sentences, tags)):
+         if tag[0] == 'I' and seen_b:
+             new_token.append(tok)
+         if tag[0] == 'B':
+             if new_token:
+                 keywords[' '.join(new_token)] = confs[begin_index]
+                 new_token = []
+             new_token.append(tok)
+             begin_index = index
+             seen_b = True
+         if tag[0] == 'O':
+             if new_token:
+                 keywords[' '.join(new_token)] = confs[begin_index]
+                 new_token = []
+             seen_b = False
+
+     # flush a keyphrase that runs to the very end of the text
+     if new_token:
+         keywords[' '.join(new_token)] = confs[begin_index]
+
+     # print('keywords before sort: ', [k for k in keywords.keys])
+     # sort by the confidence of the B- token, highest first
+     sorted_keywords = sorted(list(keywords.keys()), key=lambda kw: keywords[kw], reverse=True)
+     print('keywords after sort: ', sorted_keywords)
+     return sorted_keywords
+
+
+ def fuzzy_subword_match(key, words):
+     for index, w in enumerate(words):
+         if (len(key.split()) < len(w.split())) and key in w:
+             return index
+     return -1
+
+
+ # normalize
+ def normalize(txt):
+     general_normalization = General_normalization()
+     txt = general_normalization.alphabet_correction(txt)
+     txt = general_normalization.semi_space_correction(txt)
+     txt = general_normalization.english_correction(txt)
+     txt = general_normalization.html_correction(txt)
+     txt = general_normalization.arabic_correction(txt)
+     txt = general_normalization.punctuation_correction(txt)
+     txt = general_normalization.specials_chars(txt)
+     txt = general_normalization.remove_emojis(txt)
+     txt = general_normalization.number_correction(txt)
+     txt = general_normalization.remove_not_desired_chars(txt)
+     txt = general_normalization.remove_repeated_punctuation(txt)
+     return ' '.join(txt.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').split())
+
+
+ def remove_puncs(txt):
+     return re.sub(r'[!?،\(\)\.]', '', txt)
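A quick sketch exercising the helpers above (assuming parsinorm is installed, as listed in requirements.txt; the sample strings are made up):

import utils

print(utils.remove_puncs("هوش مصنوعی، چیست!"))  # -> 'هوش مصنوعی چیست'
print(utils.fuzzy_subword_match("هوش مصنوعی", ["آزمایشگاه هوش مصنوعی", "تهران"]))  # -> 0
print(utils.normalize("سلام!!!   دنیا\nخوبی؟"))  # collapsed whitespace, normalized characters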