abxhr committed
Commit 34473f3
1 Parent(s): ce608f1
Pipfile ADDED
@@ -0,0 +1,24 @@
+ [[source]]
+ url = "https://pypi.org/simple"
+ verify_ssl = true
+ name = "pypi"
+
+ [packages]
+ streamlit = "==0.84.2"
+ arabic-reshaper = "==2.1.3"
+ python-bidi = "==0.4.2"
+ pyarabic = "*"
+ farasapy = "==0.0.14"
+ emoji = "==1.4.2"
+ awesome-streamlit = "*"
+ torch = "==1.9.0"
+ transformers = "==4.10.0"
+ psutil = "==5.8.0"
+ fuzzysearch = "==0.7.3"
+ more-itertools = "==8.9.0"
+ cookiecutter = "*"
+
+ [dev-packages]
+
+ [requires]
+ python_version = "3.8"
README.md CHANGED
@@ -1,13 +1,9 @@
  ---
- title: Design Project
- emoji: 🏃
+ title: Arabic NLP Demo
+ emoji:
  colorFrom: purple
  colorTo: green
  sdk: streamlit
- sdk_version: 1.9.0
  app_file: app.py
- pinned: false
- license: unlicense
+ pinned: true
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
app.py ADDED
@@ -0,0 +1,48 @@
+ import awesome_streamlit as ast
+ import streamlit as st
+
+ from backend.utils import get_current_ram_usage, ga
+
+ import backend.aragpt
+ import backend.home
+ import backend.processor
+ import backend.sa
+ import backend.qa
+
+ st.set_page_config(
+     page_title="TEST", page_icon="📖", initial_sidebar_state="expanded", layout="wide"
+ )
+
+ ga(st.__file__)
+
+ PAGES = {
+     "Home": backend.home,
+     "Arabic Text Preprocessor": backend.processor,
+     "Arabic Language Generation": backend.aragpt,
+     "Arabic Sentiment Analysis": backend.sa,
+     # "Arabic Sarcasm Detection": backend.sarcasm,
+     "Arabic Question Answering": backend.qa,
+ }
+
+
+ st.sidebar.title("Navigation")
+ selection = st.sidebar.radio("Pages", list(PAGES.keys()))
+
+ page = PAGES[selection]
+ # with st.spinner(f"Loading {selection} ..."):
+ ast.shared.components.write_page(page)
+
+ st.sidebar.header("Info")
+ st.sidebar.write("Made by [Wissam Antoun](https://twitter.com/wissam_antoun)")
+ st.sidebar.write(
+     "Pre-trained models are available on [HF Hub](https://huggingface.co/aubmindlab)"
+ )
+ st.sidebar.write(
+     "Models source code available on [GitHub](https://github.com/aub-mind/arabert)"
+ )
+ st.sidebar.write(
+     "App source code available on [GitHub](https://github.com/WissamAntoun/Arabic-NLP-app)"
+ )
+ if st.sidebar.checkbox("Show RAM usage"):
+     ram = get_current_ram_usage()
+     st.sidebar.write("Ram usage: {:.2f}/{:.2f} GB".format(ram[0], ram[1]))
backend/__init__.py ADDED
File without changes
backend/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (214 Bytes)
backend/__pycache__/aragpt.cpython-38.pyc ADDED
Binary file (4.41 kB)
backend/__pycache__/modeling_gpt2.cpython-38.pyc ADDED
Binary file (42.9 kB)
backend/__pycache__/preprocess.cpython-38.pyc ADDED
Binary file (17.8 kB)
backend/__pycache__/sa_utils.cpython-38.pyc ADDED
Binary file (14.5 kB)
backend/__pycache__/services.cpython-38.pyc ADDED
Binary file (11.6 kB)
backend/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.28 kB)
backend/aragpt.py ADDED
@@ -0,0 +1,189 @@
+ import streamlit as st
+ from .services import TextGeneration
+ from tokenizers import Tokenizer
+ from functools import lru_cache
+
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
+ @lru_cache(maxsize=1)
+ def load_text_generator():
+     generator = TextGeneration()
+     generator.load()
+     return generator
+
+
+ generator = load_text_generator()
+
+ qa_prompt = """
+ أجب عن السؤال التالي:
+ """
+ qa_prompt_post = """ الجواب هو """
+ qa_prompt_post_year = """ في سنة: """
+
+
+ def write():
+     st.markdown(
+         """
+ <h1 style="text-align:left;">Arabic Language Generation</h1>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # Sidebar
+
+     # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
+     st.sidebar.subheader("Configurable parameters")
+
+     model_name = st.sidebar.selectbox(
+         "Model Selector",
+         options=[
+             "AraGPT2-Base",
+             # "AraGPT2-Medium",
+             # "Aragpt2-Large",
+             "AraGPT2-Mega",
+         ],
+         index=0,
+     )
+
+     max_new_tokens = st.sidebar.number_input(
+         "Maximum length",
+         min_value=0,
+         max_value=1024,
+         value=100,
+         help="The maximum length of the sequence to be generated.",
+     )
+     temp = st.sidebar.slider(
+         "Temperature",
+         value=1.0,
+         min_value=0.1,
+         max_value=100.0,
+         help="The value used to modulate the next token probabilities.",
+     )
+     top_k = st.sidebar.number_input(
+         "Top k",
+         value=10,
+         help="The number of highest probability vocabulary tokens to keep for top-k filtering.",
+     )
+     top_p = st.sidebar.number_input(
+         "Top p",
+         value=0.95,
+         help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
+     )
+     do_sample = st.sidebar.selectbox(
+         "Sampling?",
+         (True, False),
+         help="Whether or not to use sampling; use greedy decoding otherwise.",
+     )
+     num_beams = st.sidebar.number_input(
+         "Number of beams",
+         min_value=1,
+         max_value=10,
+         value=3,
+         help="The number of beams to use for beam search.",
+     )
+     repetition_penalty = st.sidebar.number_input(
+         "Repetition Penalty",
+         min_value=0.0,
+         value=3.0,
+         step=0.1,
+         help="The parameter for repetition penalty. 1.0 means no penalty.",
+     )
+     no_repeat_ngram_size = st.sidebar.number_input(
+         "No Repeat N-Gram Size",
+         min_value=0,
+         value=3,
+         help="If set to int > 0, all ngrams of that size can only occur once.",
+     )
+
+     st.write("#")
+
+     col = st.columns(2)
+
+     col[0].image("images/AraGPT2.png", width=200)
+
+     st.markdown(
+         """
+
+ <h3 style="text-align:left;">AraGPT2 is a GPT2 model trained from scratch on 77GB of Arabic text.</h3>
+ <h4 style="text-align:left;"> More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4>
+
+ <p style="text-align:left;"></p>
+ <p style="text-align:left;">Use the generation parameters on the sidebar to adjust generation quality.</p>
+ <p style="text-align:right;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # col[0].write(
+     #     "AraGPT2 is trained from scratch on 77GB of Arabic text. More details in our [repo](https://github.com/aub-mind/arabert/tree/master/aragpt2)."
+     # )
+     # st.write("## Generate Arabic Text")
+
+     st.markdown(
+         """
+ <style>
+ p, div, input, label, textarea{
+     text-align: right;
+ }
+ </style>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     prompt = st.text_area(
+         "Prompt",
+         "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
+     )
+     if st.button("Generate"):
+         with st.spinner("Generating..."):
+             generated_text = generator.generate(
+                 prompt=prompt,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+         st.write(generated_text)
+
+     st.markdown("---")
+     st.subheader("")
+     st.markdown(
+         """
+ <p style="text-align:left;"></p>
+ <h2 style="text-align:left;">Zero-Shot Question Answering</h2>
+
+ <p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended.</p>
+ <p style="text-align:left;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     question = st.text_input(
+         "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
+     )
+     is_date = st.checkbox("Help the model: Is the answer a date?")
+     if st.button("Answer"):
+
+         prompt2 = qa_prompt + question + qa_prompt_post
+         if is_date:
+             prompt2 += qa_prompt_post_year
+         else:
+             prompt2 += " : "
+         with st.spinner("Thinking..."):
+             answer = generator.generate(
+                 prompt=prompt2,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+         st.write(answer)
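`TextGeneration` is imported from `backend/services.py`, which this commit view does not show. A minimal sketch of the interface the page relies on, assuming the service wraps the pinned transformers==4.10.0 generation API around the public aubmindlab checkpoints (the Mega model in the real app uses the vendored `modeling_gpt2.py` below); everything except the `load()`/`generate(...)` signature is an assumption:

```python
# Hypothetical sketch only: the real backend/services.py is not in this excerpt.
# Assumes transformers==4.10.0 and torch==1.9.0 as pinned in the Pipfile.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_IDS = {  # assumed mapping from the UI labels to HF Hub repos
    "AraGPT2-Base": "aubmindlab/aragpt2-base",
}


class TextGeneration:
    def load(self):
        self.tokenizers = {k: AutoTokenizer.from_pretrained(v) for k, v in MODEL_IDS.items()}
        self.models = {k: AutoModelForCausalLM.from_pretrained(v) for k, v in MODEL_IDS.items()}

    def generate(self, prompt, model_name, max_new_tokens, temperature, top_k,
                 top_p, repetition_penalty, do_sample, num_beams, no_repeat_ngram_size):
        tokenizer = self.tokenizers[model_name]
        model = self.models[model_name]
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_beams=num_beams,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)
```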
backend/home.py ADDED
@@ -0,0 +1,156 @@
+ import streamlit as st
+ import awesome_streamlit as ast
+
+
+ def write():
+     st.markdown(
+         """
+ # Arabic Natural Language Processing
+
+ ![visitors](https://visitor-badge.glitch.me/badge?page_id=wissamantoun.arabicnlpapp)
+
+
+ In this HuggingFace space you will be able to test the different Arabic NLP models that my colleagues at [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) have built, along with some other applications.
+
+ Check the **Navigation bar** to access the apps:
+ - Arabic Text Preprocessor: Test how text input is treated by our preprocessor
+ - Arabic Language Generation: Generate Arabic text using our AraGPT2 language models
+ - Arabic Sentiment Analysis: Test the sentiment analysis model that won the [Arabic Sentiment Analysis competition @ KAUST](https://www.kaggle.com/c/arabic-sentiment-analysis-2021-kaust)
+ - Arabic Question Answering: Test our AraELECTRA QA capabilities
+ """
+     )
+     st.markdown("#")
+     col1, col2, col3 = st.columns(3)
+
+     col1.write("## **AraBERT**")
+     col1.image("images/arabert_logo.png", width=200)
+
+     col2.write("## **AraGPT2**")
+     col2.image("images/AraGPT2.png", width=200)
+
+     col3.write("## **AraElectra**")
+     col3.image("images/AraELECTRA.png", width=200)
+
+     st.markdown(
+         """
+
+ You can find more details in the source code and paper linked in our GitHub [repo](https://github.com/aub-mind/arabert).
+
+ ## Dataset
+
+ The pretraining data used for the new **AraBERT** model is also used for **AraGPT2 and AraELECTRA**.
+
+ The dataset consists of 77GB, or 200,095,961 lines, or 8,655,948,860 words, or 82,232,988,358 chars (before applying Farasa segmentation).
+
+ Our large models were trained on a TPUv3-128 provided by TFRC.
+
+ For the new dataset we added the unshuffled OSCAR corpus, after thoroughly filtering it, to the dataset used in AraBERTv1, but without the websites that we previously crawled:
+ - OSCAR unshuffled and filtered.
+ - [Arabic Wikipedia dump](https://archive.org/details/arwiki-20190201) from 2020/09/01
+ - [The 1.5B words Arabic Corpus](https://www.semanticscholar.org/paper/1.5-billion-words-Arabic-Corpus-El-Khair/f3eeef4afb81223df96575adadf808fe7fe440b4)
+ - [The OSIAN Corpus](https://www.aclweb.org/anthology/W19-4619)
+ - Assafir news articles. A huge thank you to Assafir for the data.
+
+ ## Models
+
+ Model | HuggingFace Model Name | Size (MB/Params) | Pre-Segmentation | Hardware | Sequence Length | Batch Size | Num of Steps | Total Time (in Days)
+ ---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:
+ AraBERTv0.2-base | [bert-base-arabertv02](https://huggingface.co/aubmindlab/bert-base-arabertv02) | 543MB / 136M | No | TPUv3-8 | 128 / 512 | 2560 / 384 | 1M / 2M | 36
+ AraBERTv0.2-large | [bert-large-arabertv02](https://huggingface.co/aubmindlab/bert-large-arabertv02) | 1.38G / 371M | No | TPUv3-128 | 128 / 512 | 13440 / 2056 | 250K / 300K | 7
+ AraBERTv2-base | [bert-base-arabertv2](https://huggingface.co/aubmindlab/bert-base-arabertv2) | 543MB / 136M | Yes | TPUv3-8 | 128 / 512 | 2560 / 384 | 1M / 2M | 36
+ AraBERTv2-large | [bert-large-arabertv2](https://huggingface.co/aubmindlab/bert-large-arabertv2) | 1.38G / 371M | Yes | TPUv3-128 | 128 / 512 | 13440 / 2056 | 250K / 300K | 7
+ AraBERTv0.1-base | [bert-base-arabertv01](https://huggingface.co/aubmindlab/bert-base-arabertv01) | 543MB / 136M | No | TPUv2-8 | 128 / 512 | 128 / 512 | 900K / 300K | 4
+ AraBERTv1-base | [bert-base-arabert](https://huggingface.co/aubmindlab/bert-base-arabert) | 543MB / 136M | Yes | TPUv2-8 | 128 / 512 | 128 / 512 | 900K / 300K | 4
+ AraGPT2-base | [aragpt2-base](https://huggingface.co/aubmindlab/aragpt2-base) | 527MB / 135M | No | TPUv3-128 | 1024 | 1792 | 125K | 1.5
+ AraGPT2-medium | [aragpt2-medium](https://huggingface.co/aubmindlab/aragpt2-medium) | 1.38G / 370M | No | TPUv3-8 | 1024 | 80 | 1M | 15
+ AraGPT2-large | [aragpt2-large](https://huggingface.co/aubmindlab/aragpt2-large) | 2.98GB / 792M | No | TPUv3-128 | 1024 | 256 | 220K | 3
+ AraGPT2-mega | [aragpt2-mega](https://huggingface.co/aubmindlab/aragpt2-mega) | 5.5GB / 1.46B | No | TPUv3-128 | 1024 | 256 | 800K | 9
+ AraELECTRA-base-generator | [araelectra-base-generator](https://huggingface.co/aubmindlab/araelectra-base-generator) | 227MB / 60M | No | TPUv3-8 | 512 | 256 | 2M | 24
+ AraELECTRA-base-discriminator | [araelectra-base-discriminator](https://huggingface.co/aubmindlab/araelectra-base-discriminator) | 516MB / 135M | No | TPUv3-8 | 512 | 256 | 2M | 24
+ AraBERTv0.2-Twitter-base | [bert-base-arabertv02-twitter](https://huggingface.co/aubmindlab/bert-base-arabertv02-twitter) | 543MB / 136M | No | V100 | *64* | - | - | -
+ AraBERTv0.2-Twitter-large | [bert-large-arabertv02-twitter](https://huggingface.co/aubmindlab/bert-large-arabertv02-twitter) | 1.38G / 371M | No | V100 | *64* | - | - | -
+
+ All models are available on the `HuggingFace` model page under the [aubmindlab](https://huggingface.co/aubmindlab/) name. Checkpoints are available in PyTorch, TF2 and TF1 formats.
+
+ # Preprocessing
+
+ You can test the Arabic preprocessing pipeline in the Arabic Text Preprocessor page.
+
+ It is recommended to apply our preprocessing function before training/testing on any dataset.
+ **Install farasapy to segment text for AraBERT v1 & v2: `pip install farasapy`**
+
+ ```python
+ from arabert.preprocess import ArabertPreprocessor
+
+ model_name = "aubmindlab/bert-base-arabertv2"
+ arabert_prep = ArabertPreprocessor(model_name=model_name)
+
+ text = "ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
+ arabert_prep.preprocess(text)
+ >>>"و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
+ ```
+
+ You can also use the `unpreprocess()` function to reverse the preprocessing changes, by fixing the spacing around non-alphabetical characters and by de-segmenting if the selected model needs pre-segmentation. We highly recommend un-processing the generated output of `AraGPT2`, to make it look more natural.
+ ```python
+ output_text = "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
+ arabert_prep.unpreprocess(output_text)
+ >>>"ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
+ ```
+
+ # If you used this model please cite us as:
+
+ ## AraBERT
+ Google Scholar has our Bibtex wrong (missing name), use this instead
+ ```
+ @inproceedings{antoun2020arabert,
+   title={AraBERT: Transformer-based Model for Arabic Language Understanding},
+   author={Antoun, Wissam and Baly, Fady and Hajj, Hazem},
+   booktitle={LREC 2020 Workshop Language Resources and Evaluation Conference 11--16 May 2020},
+   pages={9}
+ }
+ ```
+ ## AraGPT2
+ ```
+ @inproceedings{antoun-etal-2021-aragpt2,
+     title = "{A}ra{GPT}2: Pre-Trained Transformer for {A}rabic Language Generation",
+     author = "Antoun, Wissam and
+       Baly, Fady and
+       Hajj, Hazem",
+     booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
+     month = apr,
+     year = "2021",
+     address = "Kyiv, Ukraine (Virtual)",
+     publisher = "Association for Computational Linguistics",
+     url = "https://www.aclweb.org/anthology/2021.wanlp-1.21",
+     pages = "196--207",
+ }
+ ```
+
+ ## AraELECTRA
+ ```
+ @inproceedings{antoun-etal-2021-araelectra,
+     title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
+     author = "Antoun, Wissam and
+       Baly, Fady and
+       Hajj, Hazem",
+     booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
+     month = apr,
+     year = "2021",
+     address = "Kyiv, Ukraine (Virtual)",
+     publisher = "Association for Computational Linguistics",
+     url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
+     pages = "191--195",
+ }
+ ```
+
+
+ # Acknowledgments
+ Thanks to TensorFlow Research Cloud (TFRC) for the free access to Cloud TPUs, we couldn't have done it without this program, and to the [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) members for the continuous support. Thanks also to [Yakshof](https://www.yakshof.com/#/) and Assafir for data and storage access, and to [Habib Rahal](https://www.behance.net/rahalhabib) for putting a face to AraBERT.
+
+ # Contacts
+ **Wissam Antoun**: [Linkedin](https://www.linkedin.com/in/wissam-antoun-622142b4/) | [Twitter](https://twitter.com/wissam_antoun) | [Github](https://github.com/WissamAntoun) | wfa07 (AT) mail (DOT) aub (DOT) edu | wissam.antoun (AT) gmail (DOT) com
+
+ **Fady Baly**: [Linkedin](https://www.linkedin.com/in/fadybaly/) | [Twitter](https://twitter.com/fadybaly) | [Github](https://github.com/fadybaly) | fgb06 (AT) mail (DOT) aub (DOT) edu | baly.fady (AT) gmail (DOT) com
+
+ """
+     )
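The model table in this page maps the UI names to checkpoints on the `aubmindlab` Hub organization, all loadable through the standard `transformers` pattern. A short sketch, not part of this commit, using the pinned transformers==4.10.0:

```python
# Illustrative only: loading one of the aubmindlab checkpoints listed in the
# table above. Assumes transformers==4.10.0 as pinned in the Pipfile.
from transformers import AutoTokenizer, AutoModel

model_id = "aubmindlab/bert-base-arabertv02"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

inputs = tokenizer("مرحبا بالعالم", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 768) for the base model
```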
backend/modeling_gpt2.py ADDED
@@ -0,0 +1,1599 @@
+ # coding=utf-8
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ PyTorch OpenAI GPT-2 model.
+ Adapted from https://github.com/huggingface/transformers/blob/v4.0.1/src/transformers/models/gpt2/modeling_gpt2.py
+ and https://github.com/ghosthamlet/gpt2-ml-torch/blob/master/gpt2_ml_torch/modeling_gpt2.py
+ """
+
+
+ import logging
+ import os
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss, MSELoss
+ from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
+ from transformers.activations import ACT2FN
+ from transformers.file_utils import (
+     ModelOutput,
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     CausalLMOutputWithCrossAttentions,
+     SequenceClassifierOutputWithPast,
+     TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import (
+     Conv1D,
+     PreTrainedModel,
+     SequenceSummary,
+     find_pruneable_heads_and_indices,
+     prune_conv1d_layer,
+ )
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+
+ # The difference from Transformers is the code under _USE_GROVER
+ _USE_GROVER = True
+
+ logger = logging.getLogger(__name__)
+
+ _CONFIG_FOR_DOC = "GPT2Config"
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+     "gpt2",
+     "gpt2-medium",
+     "gpt2-large",
+     "gpt2-xl",
+     "distilgpt2",
+     # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
+ ]
+
+ logger.setLevel(logging.INFO)
+ console = logging.StreamHandler()
+ console.setLevel(logging.INFO)
+ logger.addHandler(console)
+
+ _GPT2_ML_TF_TO_TORCH = {
+     "LayerNorm_embed_norm": "emb_norm",
+     "pos_embed": "wpe.weight",
+     "word_embed": "wte.weight",
+     "layer": "h",
+     # Most importantly, these two layer norms must be placed in the same position as in gpt2-ml,
+     # or the generated output is bad (it just repeats the last token)
+     "LayerNorm_mlp_ln0": "ln_1",
+     "LayerNorm_mlp_ln1": "ln_2",
+     "intermediate": "mlp.c_fc",
+     "output": "mlp.c_proj",
+     "query_layer": "attn.c_attn",
+     "key_layer": "attn.c_attn",
+     "value_layer": "attn.c_attn",
+     "context_projection_layer": "attn.c_proj",
+     "gamma": "weight",
+     "kernel": "weight",
+     "beta": "bias",
+     "bias": "bias",
+ }
+
+
+ def convert_gpt2_checkpoint_to_pytorch(
+     gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path
+ ):
+     # Construct model
+     if gpt2_config_file == "":
+         config = GPT2Config()
+     else:
+         config = GPT2Config.from_json_file(gpt2_config_file)
+     model = GPT2Model(config)
+
+     # Load weights from numpy
+     load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
+
+     # Save pytorch-model
+     pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
+     pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
+     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
+     torch.save(model.state_dict(), pytorch_weights_dump_path)
+     print("Save configuration file to {}".format(pytorch_config_dump_path))
+     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+         f.write(config.to_json_string())
+
+
+ # XXX: MUST do like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
+ # https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
+     """Load tf checkpoints in a pytorch model"""
+     try:
+         import re
+
+         import tensorflow as tf
+     except ImportError:
+         logger.error(
+             "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
+             "https://www.tensorflow.org/install/ for installation instructions."
+         )
+         raise
+     tf_path = os.path.abspath(gpt2_checkpoint_path)
+     logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
+     # Load weights from TF model
+     init_vars = tf.train.list_variables(tf_path)
+     names = []
+     arrays = []
+     for name, shape in init_vars:
+         logger.info("Loading TF weight {} with shape {}".format(name, shape))
+         array = tf.train.load_variable(tf_path, name)
+         names.append(name)
+         arrays.append(array.squeeze())
+
+     import copy
+
+     orig_model = copy.deepcopy(model)
+
+     for name, array in zip(names, arrays):
+         name = name[6:]  # skip "model/"
+         name = name.split("/")
+         pointer = model
+
+         attn_layer = ""
+         for m_name in name:
+             if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+                 scope_names = re.split(r"(\d+)", m_name)
+             else:
+                 scope_names = [m_name]
+             sname = scope_names[0]
+
+             if sname == "" or sname == "embeddings":
+                 continue
+             elif sname not in _GPT2_ML_TF_TO_TORCH:
+                 print("=========================================================")
+                 logger.info("Skip var name {}".format(scope_names))
+                 pointer = None
+                 break
+             else:
+                 tname = _GPT2_ML_TF_TO_TORCH[sname]
+                 if "." in tname:
+                     parent, child = tname.split(".")
+                     pointer = getattr(pointer, parent)
+                     pointer = getattr(pointer, child)
+                 else:
+                     pointer = getattr(pointer, tname)
+
+                 if tname == "attn.c_attn":
+                     attn_layer = sname
+
+             if len(scope_names) >= 2:
+                 num = int(scope_names[1])
+                 pointer = pointer[num]
+
+         if pointer is None:
+             continue
+         if attn_layer == "":
+             try:
+                 assert pointer.shape == array.shape
+             except AssertionError as e:
+                 e.args += (pointer.shape, array.shape)
+                 raise
+         logger.info(
+             "Initialize PyTorch weight {}, {}, {}".format(
+                 name, array.mean(), pointer.mean()
+             )
+         )
+         if attn_layer == "":
+             pointer.data = torch.from_numpy(array)
+         else:
+             shape = pointer.shape
+             d = torch.from_numpy(array)
+             is_bias = len(shape) == 1
+             end = int(shape[0 if is_bias else 1] / 3)
+             m = dict(
+                 query_layer=0,
+                 key_layer=end,
+                 value_layer=end * 2,
+             )
+             start = m[attn_layer]
+             end = start + end
+             if is_bias:
+                 pointer.data[start:end] = d
+             else:
+                 pointer.data[:, start:end] = d
+             logger.info(
+                 "Initialize PyTorch weight {}, {}, {}".format(
+                     name, array.mean(), pointer.mean()
+                 )
+             )
+
+     for name, params in orig_model.named_parameters():
+         for n, p in model.named_parameters():
+             if name == n:
+                 if params.equal(p):
+                     print("--------------------------")
+                     print(" %s not changed!" % n)
+     return model
+
+
+ class Attention(nn.Module):
+     def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
+         super().__init__()
+
+         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
+         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
+         assert n_state % config.n_head == 0
+         self.register_buffer(
+             "bias",
+             torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(
+                 1, 1, n_ctx, n_ctx
+             ),
+         )
+         self.register_buffer("masked_bias", torch.tensor(-1e4))
+         self.n_head = config.n_head
+         self.split_size = n_state
+         self.scale = scale
+         self.is_cross_attention = is_cross_attention
+         if self.is_cross_attention:
+             self.c_attn = Conv1D(2 * n_state, nx)
+             self.q_attn = Conv1D(n_state, nx)
+         else:
+             self.c_attn = Conv1D(3 * n_state, nx)
+         self.c_proj = Conv1D(n_state, nx)
+         self.attn_dropout = nn.Dropout(config.attn_pdrop)
+         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+         self.pruned_heads = set()
+
+     def prune_heads(self, heads):
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+         )
+         index_attn = torch.cat(
+             [index, index + self.split_size, index + (2 * self.split_size)]
+         )
+
+         # Prune conv1d layers
+         self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+         self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+         # Update hyper params
+         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+         self.n_head = self.n_head - len(heads)
+         self.pruned_heads = self.pruned_heads.union(heads)
+
+     def _attn(
+         self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False
+     ):
+         w = torch.matmul(q, k)
+         if self.scale:
+             w = w / (float(v.size(-1)) ** 0.5)
+         nd, ns = w.size(-2), w.size(-1)
+
+         if not self.is_cross_attention:
+             # only the "normal" attention layer implements causal mask
+             mask = self.bias[:, :, ns - nd : ns, :ns]
+             w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
+
+         if attention_mask is not None:
+             # Apply the attention mask
+             w = w + attention_mask
+
+         w = nn.Softmax(dim=-1)(w)
+         w = self.attn_dropout(w)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             w = w * head_mask
+
+         outputs = [torch.matmul(w, v)]
+         if output_attentions:
+             outputs.append(w)
+         return outputs
+
+     def merge_heads(self, x):
+         x = x.permute(0, 2, 1, 3).contiguous()
+         new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+         return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+
+     def split_heads(self, x, k=False):
+         new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+         x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
+         if k:
+             return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
+         else:
+             return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+     def forward(
+         self,
+         hidden_states,
+         layer_past=None,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         use_cache=False,
+         output_attentions=False,
+     ):
+         if encoder_hidden_states is not None:
+             assert hasattr(
+                 self, "q_attn"
+             ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
+             query = self.q_attn(hidden_states)
+             key, value = self.c_attn(encoder_hidden_states).split(
+                 self.split_size, dim=2
+             )
+             attention_mask = encoder_attention_mask
+         else:
+             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+         query = self.split_heads(query)
+         key = self.split_heads(key, k=True)
+         value = self.split_heads(value)
+         if layer_past is not None:
+             past_key, past_value = (
+                 layer_past[0].transpose(-2, -1),
+                 layer_past[1],
+             )  # transpose back cf below
+             key = torch.cat((past_key, key), dim=-1)
+             value = torch.cat((past_value, value), dim=-2)
+
+         if use_cache is True:
+             present = torch.stack(
+                 (key.transpose(-2, -1), value)
+             )  # transpose to have same shapes for stacking
+         else:
+             present = (None,)
+
+         attn_outputs = self._attn(
+             query, key, value, attention_mask, head_mask, output_attentions
+         )
+         a = attn_outputs[0]
+
+         a = self.merge_heads(a)
+         a = self.c_proj(a)
+         a = self.resid_dropout(a)
+
+         outputs = [a, present] + attn_outputs[1:]
+         return outputs  # a, present, (attentions)
+
+
+ class MLP(nn.Module):
+     def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
+         super().__init__()
+         nx = config.n_embd
+         self.c_fc = Conv1D(n_state, nx)
+         self.c_proj = Conv1D(nx, n_state)
+         self.act = ACT2FN[config.activation_function]
+         self.dropout = nn.Dropout(config.resid_pdrop)
+
+     def forward(self, x):
+         h = self.act(self.c_fc(x))
+         h2 = self.c_proj(h)
+         return self.dropout(h2)
+
+
+ class Block(nn.Module):
+     def __init__(self, n_ctx, config, scale=False):
+         super().__init__()
+         hidden_size = config.n_embd
+         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.attn = Attention(hidden_size, n_ctx, config, scale)
+         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         if config.add_cross_attention:
+             self.crossattention = Attention(
+                 hidden_size, n_ctx, config, scale, is_cross_attention=True
+             )
+             self.ln_cross_attn = nn.LayerNorm(
+                 hidden_size, eps=config.layer_norm_epsilon
+             )
+         self.mlp = MLP(inner_dim, config)
+
+     def forward(
+         self,
+         hidden_states,
+         layer_past=None,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         use_cache=False,
+         output_attentions=False,
+     ):
+         attn_outputs = self.attn(
+             hidden_states,
+             layer_past=layer_past,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+         )
+         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+         outputs = attn_outputs[1:]
+         # residual connection
+         hidden_states = attn_output + hidden_states
+
+         if encoder_hidden_states is not None:
+             # add one self-attention block for cross-attention
+             assert hasattr(
+                 self, "crossattention"
+             ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+             cross_attn_outputs = self.crossattention(
+                 self.ln_cross_attn(hidden_states),
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 output_attentions=output_attentions,
+             )
+             attn_output = cross_attn_outputs[0]
+             # residual connection
+             hidden_states = hidden_states + attn_output
+             outputs = (
+                 outputs + cross_attn_outputs[2:]
+             )  # add cross attentions if we output attention weights
+
+         feed_forward_hidden_states = self.mlp(self.ln_1(hidden_states))
+         # residual connection
+         hidden_states = hidden_states + feed_forward_hidden_states
+
+         hidden_states = self.ln_2(hidden_states)
+
+         outputs = [hidden_states] + outputs
+         return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
+ class GPT2PreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = GPT2Config
+     load_tf_weights = load_tf_weights_in_gpt2
+     base_model_prefix = "transformer"
+     is_parallelizable = True
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+     def _init_weights(self, module):
+         """Initialize the weights."""
+         if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ @dataclass
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
+     """
+     Base class for outputs of models predicting if two sentences are consecutive or not.
+
+     Args:
+         loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
+             Language modeling loss.
+         mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
+             Multiple choice classification loss.
+         logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+         mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
+             Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+         past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+             List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
+             batch_size, num_heads, sequence_length, embed_size_per_head)`).
+
+             Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+             :obj:`past_key_values` input) to speed up sequential decoding.
+         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+             Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+         attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+             sequence_length, sequence_length)`.
+
+             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     mc_loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     mc_logits: torch.FloatTensor = None
+     past_key_values: Optional[List[torch.FloatTensor]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ GPT2_START_DOCSTRING = r"""
+
+     This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+     methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+     pruning heads etc.)
+
+     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+     general usage and behavior.
+
+     Parameters:
+         config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+             weights.
+ """
+
+ GPT2_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
+             :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
+             ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
+             sequence tokens in the vocabulary.
+
+             If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
+             passed as ``input_ids``.
+
+             Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
+             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+             details.
+
+             `What are input IDs? <../glossary.html#input-ids>`__
+         past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+             Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
+             :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
+             have their past given to this model should not be passed as ``input_ids`` as they have already been
+             computed.
+         attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             `What are attention masks? <../glossary.html#attention-mask>`__
+         token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
+             Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+             1]``:
+
+             - 0 corresponds to a `sentence A` token,
+             - 1 corresponds to a `sentence B` token.
+
+             `What are token type IDs? <../glossary.html#token-type-ids>`_
+         position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+             config.max_position_embeddings - 1]``.
+
+             `What are position IDs? <../glossary.html#position-ids>`_
+         head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+             Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+
+         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+             This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+             vectors than the model's internal embedding lookup matrix.
+
+             If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
+             :obj:`past_key_values`).
+         use_cache (:obj:`bool`, `optional`):
+             If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+             decoding (see :obj:`past_key_values`).
+         output_attentions (:obj:`bool`, `optional`):
+             Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+             tensors for more detail.
+         output_hidden_states (:obj:`bool`, `optional`):
+             Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+             more detail.
+         return_dict (:obj:`bool`, `optional`):
+             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+ """
+
+ PARALLELIZE_DOCSTRING = r"""
+     This is an experimental feature and is subject to change at a moment's notice.
+
+     Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
+     it will evenly distribute blocks across all devices.
+
+     Args:
+         device_map (:obj:`Dict[int, list]`, optional, defaults to None):
+             A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
+             automatically mapped to the first device (for esoteric reasons). That means that the first device should
+             have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
+             following number of attention modules:
+
+                 - gpt2: 12
+                 - gpt2-medium: 24
+                 - gpt2-large: 36
+                 - gpt2-xl: 48
+
+     Example::
+
+             # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
+             model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
+             device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
+
+                           1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
+                           2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
+                           3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
+             model.parallelize(device_map)
+ """
+ DEPARALLELIZE_DOCSTRING = r"""
+     Moves the model to cpu from a model parallel state.
+
+     Example::
+
+         # On a 4 GPU machine with gpt2-large:
+         model = GPT2LMHeadModel.from_pretrained('gpt2-large')
+         device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
+
+                       1: [8, 9, 10, 11, 12, 13, 14, 15],
+                       2: [16, 17, 18, 19, 20, 21, 22, 23],
+                       3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
+         model.parallelize(device_map)  # Splits the model across several devices
+         model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
+ """
+
+
+ @add_start_docstrings(
+     "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
+     GPT2_START_DOCSTRING,
+ )
+ class GPT2Model(GPT2PreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+         if _USE_GROVER:
+             self.emb_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+         self.drop = nn.Dropout(config.embd_pdrop)
+         self.h = nn.ModuleList(
+             [Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]
+         )
+         if not _USE_GROVER:
+             self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+         self.init_weights()
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+     @add_start_docstrings(PARALLELIZE_DOCSTRING)
+     def parallelize(self, device_map=None):
+         # Check validity of device_map
+         self.device_map = (
+             get_device_map(len(self.h), range(torch.cuda.device_count()))
+             if device_map is None
+             else device_map
+         )
+         assert_device_map(self.device_map, len(self.h))
+         self.model_parallel = True
+         self.first_device = (
+             "cpu"
+             if "cpu" in self.device_map.keys()
+             else "cuda:" + str(min(self.device_map.keys()))
+         )
+         self.last_device = "cuda:" + str(max(self.device_map.keys()))
+         self.wte = self.wte.to(self.first_device)
+         self.wpe = self.wpe.to(self.first_device)
+         # Load onto devices
+         for k, v in self.device_map.items():
+             for block in v:
+                 cuda_device = "cuda:" + str(k)
+                 self.h[block] = self.h[block].to(cuda_device)
+         # ln_f to last
+         self.ln_f = self.ln_f.to(self.last_device)
+
+     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+     def deparallelize(self):
+         self.model_parallel = False
+         self.device_map = None
+         self.first_device = "cpu"
+         self.last_device = "cpu"
+         self.wte = self.wte.to("cpu")
+         self.wpe = self.wpe.to("cpu")
+         for index in range(len(self.h)):
+             self.h[index] = self.h[index].to("cpu")
+         self.ln_f = self.ln_f.to("cpu")
+         torch.cuda.empty_cache()
+
+     def get_input_embeddings(self):
+         return self.wte
+
+     def set_input_embeddings(self, new_embeddings):
+         self.wte = new_embeddings
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+         """
+         for layer, heads in heads_to_prune.items():
+             self.h[layer].attn.prune_heads(heads)
+
+     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         tokenizer_class=_TOKENIZER_FOR_DOC,
+         checkpoint="gpt2",
+         output_type=BaseModelOutputWithPastAndCrossAttentions,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def forward(
+         self,
+         input_ids=None,
+         past_key_values=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+             batch_size = input_ids.shape[0]
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+             batch_size = inputs_embeds.shape[0]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if token_type_ids is not None:
+             token_type_ids = token_type_ids.view(-1, input_shape[-1])
+         if position_ids is not None:
+             position_ids = position_ids.view(-1, input_shape[-1])
+
+         if past_key_values is None:
+             past_length = 0
+             past_key_values = [None] * len(self.h)
+         else:
+             past_length = past_key_values[0][0].size(-2)
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_length,
+                 input_shape[-1] + past_length,
+                 dtype=torch.long,
+                 device=device,
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+         # Attention mask.
+         if attention_mask is not None:
+             if batch_size <= 0:
+                 raise ValueError("batch_size has to be defined and > 0")
+             attention_mask = attention_mask.view(batch_size, -1)
+             # We create a 3D attention mask from a 2D tensor mask.
+             # Sizes are [batch_size, 1, 1, to_seq_length]
+             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+             # this attention mask is more simple than the triangular masking of causal attention
+             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+             attention_mask = attention_mask[:, None, None, :]
+
+             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+             # masked positions, this operation will create a tensor which is 0.0 for
+             # positions we want to attend and -10000.0 for masked positions.
+             # Since we are adding it to the raw scores before the softmax, this is
+             # effectively the same as removing these entirely.
+             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+             attention_mask = (1.0 - attention_mask) * -10000.0
+
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.config.add_cross_attention and encoder_hidden_states is not None:
+             (
+                 encoder_batch_size,
+                 encoder_sequence_length,
+                 _,
+             ) = encoder_hidden_states.size()
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+             encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_attention_mask = None
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # head_mask has shape n_layer x batch x n_heads x N x N
+         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.wte(input_ids)
+         position_embeds = self.wpe(position_ids)
+         hidden_states = inputs_embeds + position_embeds
+
+         if token_type_ids is not None:
+             token_type_embeds = self.wte(token_type_ids)
+             hidden_states = hidden_states + token_type_embeds
+
+         hidden_states = self.drop(hidden_states)
+         if _USE_GROVER:
+             hidden_states = self.emb_norm(hidden_states)
+         output_shape = input_shape + (hidden_states.size(-1),)
+
+         presents = () if use_cache else None
+         all_self_attentions = () if output_attentions else None
+         all_cross_attentions = (
+             () if output_attentions and self.config.add_cross_attention else None
+         )
+         all_hidden_states = () if output_hidden_states else None
+         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+             # Model parallel
+             if self.model_parallel:
+                 torch.cuda.set_device(hidden_states.device)
+                 # Ensure layer_past is on same device as hidden_states (might not be correct)
+                 if layer_past is not None:
+                     layer_past = tuple(
+                         past_state.to(hidden_states.device) for past_state in layer_past
+                     )
+                 # Ensure that attention_mask is always on the same device as hidden_states
+                 if attention_mask is not None:
+                     attention_mask = attention_mask.to(hidden_states.device)
+                 if isinstance(head_mask, torch.Tensor):
+                     head_mask = head_mask.to(hidden_states.device)
+
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (
+                     hidden_states.view(*output_shape),
+                 )
+
+             if getattr(self.config, "gradient_checkpointing", False):
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         # checkpointing only works with tuple returns, not with lists
+                         return tuple(
+                             output
+                             for output in module(*inputs, use_cache, output_attentions)
+                         )
+
+                     return custom_forward
+
+                 outputs = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     layer_past,
+                     attention_mask,
+                     head_mask[i],
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                 )
+             else:
+                 outputs = block(
+                     hidden_states,
+                     layer_past=layer_past,
+                     attention_mask=attention_mask,
+                     head_mask=head_mask[i],
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_attention_mask=encoder_attention_mask,
+                     use_cache=use_cache,
+                     output_attentions=output_attentions,
+                 )
+
+             hidden_states, present = outputs[:2]
+             if use_cache is True:
+                 presents = presents + (present,)
+
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (
+                     outputs[2 if use_cache else 1],
+                 )
+                 if self.config.add_cross_attention:
+                     all_cross_attentions = all_cross_attentions + (
+                         outputs[3 if use_cache else 2],
+                     )
+
+             # Model Parallel: If it's the last layer for that device, put things on the next device
+             if self.model_parallel:
+                 for k, v in self.device_map.items():
+                     if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                         hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+         if not _USE_GROVER:
+             hidden_states = self.ln_f(hidden_states)
+
+         hidden_states = hidden_states.view(*output_shape)
+         # Add last hidden state
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     presents,
+                     all_hidden_states,
+                     all_self_attentions,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=presents,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+
+
+ @add_start_docstrings(
+     """
+     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+     embeddings).
+     """,
+     GPT2_START_DOCSTRING,
+ )
+ class GPT2LMHeadModel(GPT2PreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.transformer = GPT2Model(config)
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         self.init_weights()
+
+         # Model parallel
994
+ self.model_parallel = False
995
+ self.device_map = None
996
+
997
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
998
+ def parallelize(self, device_map=None):
999
+ self.device_map = (
1000
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1001
+ if device_map is None
1002
+ else device_map
1003
+ )
1004
+ assert_device_map(self.device_map, len(self.transformer.h))
1005
+ self.transformer.parallelize(self.device_map)
1006
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1007
+ self.model_parallel = True
1008
+
1009
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1010
+ def deparallelize(self):
1011
+ self.transformer.deparallelize()
1012
+ self.transformer = self.transformer.to("cpu")
1013
+ self.lm_head = self.lm_head.to("cpu")
1014
+ self.model_parallel = False
1015
+ torch.cuda.empty_cache()
1016
+
1017
+ def get_output_embeddings(self):
1018
+ return self.lm_head
1019
+
1020
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1021
+ token_type_ids = kwargs.get("token_type_ids", None)
1022
+ # only use the last token of input_ids if past is defined in kwargs
1023
+ if past:
1024
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1025
+ if token_type_ids is not None:
1026
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1027
+
1028
+ attention_mask = kwargs.get("attention_mask", None)
1029
+ position_ids = kwargs.get("position_ids", None)
1030
+
1031
+ if attention_mask is not None and position_ids is None:
1032
+ # create position_ids on the fly for batch generation
1033
+ position_ids = attention_mask.long().cumsum(-1) - 1
1034
+ position_ids.masked_fill_(attention_mask == 0, 1)
1035
+ if past:
1036
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1037
+ else:
1038
+ position_ids = None
1039
+ return {
1040
+ "input_ids": input_ids,
1041
+ "past_key_values": past,
1042
+ "use_cache": kwargs.get("use_cache"),
1043
+ "position_ids": position_ids,
1044
+ "attention_mask": attention_mask,
1045
+ "token_type_ids": token_type_ids,
1046
+ }
1047
+
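The cumsum trick in `prepare_inputs_for_generation` above rebuilds `position_ids` from the attention mask, so left-padded rows still count positions from their first real token. A small sketch with assumed values:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])     # one left-padded row
position_ids = attention_mask.long().cumsum(-1) - 1  # [[-1, -1, 0, 1, 2]]
position_ids.masked_fill_(attention_mask == 0, 1)    # [[ 1,  1, 0, 1, 2]]
# real tokens get positions 0..2; pad slots get a harmless dummy value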
1048
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1049
+ @add_code_sample_docstrings(
1050
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1051
+ checkpoint="gpt2",
1052
+ output_type=CausalLMOutputWithCrossAttentions,
1053
+ config_class=_CONFIG_FOR_DOC,
1054
+ )
1055
+ def forward(
1056
+ self,
1057
+ input_ids=None,
1058
+ past_key_values=None,
1059
+ attention_mask=None,
1060
+ token_type_ids=None,
1061
+ position_ids=None,
1062
+ head_mask=None,
1063
+ inputs_embeds=None,
1064
+ encoder_hidden_states=None,
1065
+ encoder_attention_mask=None,
1066
+ labels=None,
1067
+ use_cache=None,
1068
+ output_attentions=None,
1069
+ output_hidden_states=None,
1070
+ return_dict=None,
1071
+ ):
1072
+ r"""
1073
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1074
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1075
+ ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
1076
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
1077
+ """
1078
+ return_dict = (
1079
+ return_dict if return_dict is not None else self.config.use_return_dict
1080
+ )
1081
+
1082
+ transformer_outputs = self.transformer(
1083
+ input_ids,
1084
+ past_key_values=past_key_values,
1085
+ attention_mask=attention_mask,
1086
+ token_type_ids=token_type_ids,
1087
+ position_ids=position_ids,
1088
+ head_mask=head_mask,
1089
+ inputs_embeds=inputs_embeds,
1090
+ encoder_hidden_states=encoder_hidden_states,
1091
+ encoder_attention_mask=encoder_attention_mask,
1092
+ use_cache=use_cache,
1093
+ output_attentions=output_attentions,
1094
+ output_hidden_states=output_hidden_states,
1095
+ return_dict=return_dict,
1096
+ )
1097
+ hidden_states = transformer_outputs[0]
1098
+
1099
+ # Set device for model parallelism
1100
+ if self.model_parallel:
1101
+ torch.cuda.set_device(self.transformer.first_device)
1102
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1103
+
1104
+ lm_logits = self.lm_head(hidden_states)
1105
+
1106
+ loss = None
1107
+ if labels is not None:
1108
+ # Shift so that tokens < n predict n
1109
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1110
+ shift_labels = labels[..., 1:].contiguous()
1111
+ # Flatten the tokens
1112
+ loss_fct = CrossEntropyLoss()
1113
+ loss = loss_fct(
1114
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1115
+ )
1116
+
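The shift above is the standard causal-LM loss: the logits at position i are scored against the token at position i+1, which is why both tensors drop one step. A toy sketch with assumed shapes:

import torch
from torch.nn import CrossEntropyLoss

lm_logits = torch.randn(1, 4, 10)  # (batch, seq_len, vocab), random stand-ins
labels = torch.tensor([[5, 2, 7, 1]])

shift_logits = lm_logits[..., :-1, :].contiguous()  # predictions for positions 0..2
shift_labels = labels[..., 1:].contiguous()         # targets are the next tokens 1..3
loss = CrossEntropyLoss()(shift_logits.view(-1, 10), shift_labels.view(-1))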
1117
+ if not return_dict:
1118
+ output = (lm_logits,) + transformer_outputs[1:]
1119
+ return ((loss,) + output) if loss is not None else output
1120
+
1121
+ return CausalLMOutputWithCrossAttentions(
1122
+ loss=loss,
1123
+ logits=lm_logits,
1124
+ past_key_values=transformer_outputs.past_key_values,
1125
+ hidden_states=transformer_outputs.hidden_states,
1126
+ attentions=transformer_outputs.attentions,
1127
+ cross_attentions=transformer_outputs.cross_attentions,
1128
+ )
1129
+
1130
+ @staticmethod
1131
+ def _reorder_cache(
1132
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1133
+ ) -> Tuple[Tuple[torch.Tensor]]:
1134
+ """
1135
+ This function is used to re-order the :obj:`past_key_values` cache if
1136
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1137
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1138
+ """
1139
+ return tuple(
1140
+ tuple(
1141
+ past_state.index_select(0, beam_idx.to(past_state.device))
1142
+ for past_state in layer_past
1143
+ )
1144
+ for layer_past in past
1145
+ )
1146
+
1147
+
1148
+ @add_start_docstrings(
1149
+ """
1150
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1151
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1152
+ input embeddings, the classification head takes as input the input of a specified classification token index in the
1153
+ input sequence).
1154
+ """,
1155
+ GPT2_START_DOCSTRING,
1156
+ )
1157
+ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1158
+ def __init__(self, config):
1159
+ super().__init__(config)
1160
+ config.num_labels = 1
1161
+ self.transformer = GPT2Model(config)
1162
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1163
+ self.multiple_choice_head = SequenceSummary(config)
1164
+
1165
+ self.init_weights()
1166
+
1167
+ # Model parallel
1168
+ self.model_parallel = False
1169
+ self.device_map = None
1170
+
1171
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1172
+ def parallelize(self, device_map=None):
1173
+ self.device_map = (
1174
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1175
+ if device_map is None
1176
+ else device_map
1177
+ )
1178
+ assert_device_map(self.device_map, len(self.transformer.h))
1179
+ self.transformer.parallelize(self.device_map)
1180
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1181
+ self.multiple_choice_head = self.multiple_choice_head.to(
1182
+ self.transformer.first_device
1183
+ )
1184
+ self.model_parallel = True
1185
+
1186
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1187
+ def deparallelize(self):
1188
+ self.transformer.deparallelize()
1189
+ self.transformer = self.transformer.to("cpu")
1190
+ self.lm_head = self.lm_head.to("cpu")
1191
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1192
+ self.model_parallel = False
1193
+ torch.cuda.empty_cache()
1194
+
1195
+ def get_output_embeddings(self):
1196
+ return self.lm_head
1197
+
1198
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1199
+ token_type_ids = kwargs.get("token_type_ids", None)
1200
+ # only use the last token of input_ids if past is defined in kwargs
1201
+ if past:
1202
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1203
+ if token_type_ids is not None:
1204
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1205
+
1206
+ attention_mask = kwargs.get("attention_mask", None)
1207
+ position_ids = kwargs.get("position_ids", None)
1208
+
1209
+ if attention_mask is not None and position_ids is None:
1210
+ # create position_ids on the fly for batch generation
1211
+ position_ids = attention_mask.long().cumsum(-1) - 1
1212
+ position_ids.masked_fill_(attention_mask == 0, 1)
1213
+ if past:
1214
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1215
+ else:
1216
+ position_ids = None
1217
+
1218
+ return {
1219
+ "input_ids": input_ids,
1220
+ "past_key_values": past,
1221
+ "use_cache": kwargs.get("use_cache"),
1222
+ "position_ids": position_ids,
1223
+ "attention_mask": attention_mask,
1224
+ "token_type_ids": token_type_ids,
1225
+ }
1226
+
1227
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1228
+ @replace_return_docstrings(
1229
+ output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
1230
+ )
1231
+ def forward(
1232
+ self,
1233
+ input_ids=None,
1234
+ past_key_values=None,
1235
+ attention_mask=None,
1236
+ token_type_ids=None,
1237
+ position_ids=None,
1238
+ head_mask=None,
1239
+ inputs_embeds=None,
1240
+ mc_token_ids=None,
1241
+ labels=None,
1242
+ mc_labels=None,
1243
+ use_cache=None,
1244
+ output_attentions=None,
1245
+ output_hidden_states=None,
1246
+ return_dict=None,
1247
+ **kwargs,
1248
+ ):
1249
+ r"""
1250
+ mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
1251
+ Index of the classification token in each input sequence. Selected in the range ``[0, input_ids.size(-1) -
1252
+ 1[``.
1253
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1254
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1255
+ ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
1256
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
1257
+ mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`):
1258
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1259
+ num_choices]`` where `num_choices` is the size of the second dimension of the input tensors. (see
1260
+ `input_ids` above)
1261
+
1262
+ Return:
1263
+
1264
+ Example::
1265
+
1266
+ >>> import torch
1267
+ >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
1268
+
1269
+ >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
1270
+ >>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
1271
+
1272
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1273
+ >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})
1274
+
1275
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
1276
+
1277
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1278
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1279
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1280
+
1281
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1282
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1283
+
1284
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1285
+ >>> lm_logits = outputs.lm_logits
1286
+ >>> mc_logits = outputs.mc_logits
1287
+
1288
+ """
1289
+ return_dict = (
1290
+ return_dict if return_dict is not None else self.config.use_return_dict
1291
+ )
1292
+
1293
+ transformer_outputs = self.transformer(
1294
+ input_ids,
1295
+ past_key_values=past_key_values,
1296
+ attention_mask=attention_mask,
1297
+ token_type_ids=token_type_ids,
1298
+ position_ids=position_ids,
1299
+ head_mask=head_mask,
1300
+ inputs_embeds=inputs_embeds,
1301
+ use_cache=use_cache,
1302
+ output_attentions=output_attentions,
1303
+ output_hidden_states=output_hidden_states,
1304
+ return_dict=return_dict,
1305
+ )
1306
+
1307
+ hidden_states = transformer_outputs[0]
1308
+
1309
+ # Set device for model parallelism
1310
+ if self.model_parallel:
1311
+ torch.cuda.set_device(self.transformer.first_device)
1312
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1313
+
1314
+ lm_logits = self.lm_head(hidden_states)
1315
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1316
+
1317
+ mc_loss = None
1318
+ if mc_labels is not None:
1319
+ loss_fct = CrossEntropyLoss()
1320
+ mc_loss = loss_fct(
1321
+ mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
1322
+ )
1323
+ lm_loss = None
1324
+ if labels is not None:
1325
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1326
+ shift_labels = labels[..., 1:].contiguous()
1327
+ loss_fct = CrossEntropyLoss()
1328
+ lm_loss = loss_fct(
1329
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1330
+ )
1331
+
1332
+ if not return_dict:
1333
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
1334
+ if mc_loss is not None:
1335
+ output = (mc_loss,) + output
1336
+ return ((lm_loss,) + output) if lm_loss is not None else output
1337
+
1338
+ return GPT2DoubleHeadsModelOutput(
1339
+ loss=lm_loss,
1340
+ mc_loss=mc_loss,
1341
+ logits=lm_logits,
1342
+ mc_logits=mc_logits,
1343
+ past_key_values=transformer_outputs.past_key_values,
1344
+ hidden_states=transformer_outputs.hidden_states,
1345
+ attentions=transformer_outputs.attentions,
1346
+ )
1347
+
1348
+ @staticmethod
1349
+ def _reorder_cache(
1350
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1351
+ ) -> Tuple[Tuple[torch.Tensor]]:
1352
+ """
1353
+ This function is used to re-order the :obj:`past_key_values` cache if
1354
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1355
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1356
+ """
1357
+ return tuple(
1358
+ tuple(
1359
+ past_state.index_select(0, beam_idx.to(past_state.device))
1360
+ for past_state in layer_past
1361
+ )
1362
+ for layer_past in past
1363
+ )
1364
+
1365
+
1366
+ @add_start_docstrings(
1367
+ """
1368
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
1369
+
1370
+ :class:`~transformers.GPT2ForSequenceClassification` uses the last token in order to do the classification, as
1371
+ other causal models (e.g. GPT-1) do.
1372
+
1373
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1374
+ :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
1375
+ row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
1376
+ guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
1377
+ the last value in each row of the batch).
1378
+ """,
1379
+ GPT2_START_DOCSTRING,
1380
+ )
1381
+ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1382
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
1383
+
1384
+ def __init__(self, config):
1385
+ super().__init__(config)
1386
+ self.num_labels = config.num_labels
1387
+ self.transformer = GPT2Model(config)
1388
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1389
+
1390
+ self.init_weights()
1391
+
1392
+ # Model parallel
1393
+ self.model_parallel = False
1394
+ self.device_map = None
1395
+
1396
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1397
+ @add_code_sample_docstrings(
1398
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1399
+ checkpoint="microsoft/dialogrpt",
1400
+ output_type=SequenceClassifierOutputWithPast,
1401
+ config_class=_CONFIG_FOR_DOC,
1402
+ )
1403
+ def forward(
1404
+ self,
1405
+ input_ids=None,
1406
+ past_key_values=None,
1407
+ attention_mask=None,
1408
+ token_type_ids=None,
1409
+ position_ids=None,
1410
+ head_mask=None,
1411
+ inputs_embeds=None,
1412
+ labels=None,
1413
+ use_cache=None,
1414
+ output_attentions=None,
1415
+ output_hidden_states=None,
1416
+ return_dict=None,
1417
+ ):
1418
+ r"""
1419
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1420
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1421
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1422
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1423
+ """
1424
+ return_dict = (
1425
+ return_dict if return_dict is not None else self.config.use_return_dict
1426
+ )
1427
+
1428
+ transformer_outputs = self.transformer(
1429
+ input_ids,
1430
+ past_key_values=past_key_values,
1431
+ attention_mask=attention_mask,
1432
+ token_type_ids=token_type_ids,
1433
+ position_ids=position_ids,
1434
+ head_mask=head_mask,
1435
+ inputs_embeds=inputs_embeds,
1436
+ use_cache=use_cache,
1437
+ output_attentions=output_attentions,
1438
+ output_hidden_states=output_hidden_states,
1439
+ return_dict=return_dict,
1440
+ )
1441
+ hidden_states = transformer_outputs[0]
1442
+ logits = self.score(hidden_states)
1443
+
1444
+ if input_ids is not None:
1445
+ batch_size, sequence_length = input_ids.shape[:2]
1446
+ else:
1447
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1448
+
1449
+ assert (
1450
+ self.config.pad_token_id is not None or batch_size == 1
1451
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1452
+ if self.config.pad_token_id is None:
1453
+ sequence_lengths = -1
1454
+ else:
1455
+ if input_ids is not None:
1456
+ sequence_lengths = (
1457
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1458
+ )
1459
+ else:
1460
+ sequence_lengths = -1
1461
+ logger.warning(
1462
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1463
+ f"unexpected if using padding tokens in conjunction with `inputs_embeds`."
1464
+ )
1465
+
1466
+ pooled_logits = logits[range(batch_size), sequence_lengths]
1467
+
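Because GPT-2 has no [CLS] token, the classifier pools the logits of the last non-padding token in each row. A minimal sketch of how `sequence_lengths` picks that index (pad id and values assumed):

import torch

pad_token_id = 0
input_ids = torch.tensor([[7, 8, 9, 0, 0],
                          [4, 5, 0, 0, 0]])
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 1])
logits = torch.randn(2, 5, 3)                                     # (batch, seq, num_labels)
pooled = logits[range(2), sequence_lengths]                       # last real token's logits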
1468
+ loss = None
1469
+ if labels is not None:
1470
+ if self.num_labels == 1:
1471
+ # We are doing regression
1472
+ loss_fct = MSELoss()
1473
+ loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
1474
+ else:
1475
+ loss_fct = CrossEntropyLoss()
1476
+ loss = loss_fct(
1477
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1478
+ )
1479
+
1480
+ if not return_dict:
1481
+ output = (pooled_logits,) + transformer_outputs[1:]
1482
+ return ((loss,) + output) if loss is not None else output
1483
+
1484
+ return SequenceClassifierOutputWithPast(
1485
+ loss=loss,
1486
+ logits=pooled_logits,
1487
+ past_key_values=transformer_outputs.past_key_values,
1488
+ hidden_states=transformer_outputs.hidden_states,
1489
+ attentions=transformer_outputs.attentions,
1490
+ )
1491
+
1492
+
1493
+ @add_start_docstrings(
1494
+ """
1495
+ GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1496
+ Named-Entity-Recognition (NER) tasks.
1497
+ """,
1498
+ GPT2_START_DOCSTRING,
1499
+ )
1500
+ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1501
+ def __init__(self, config):
1502
+ super().__init__(config)
1503
+ self.num_labels = config.num_labels
1504
+
1505
+ self.transformer = GPT2Model(config)
1506
+ if (
1507
+ hasattr(config, "classifier_dropout")
1508
+ and config.classifier_dropout is not None
1509
+ ):
1510
+ classifier_dropout = config.classifier_dropout
1511
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1512
+ classifier_dropout = config.hidden_dropout
1513
+ else:
1514
+ classifier_dropout = 0.1
1515
+ self.dropout = nn.Dropout(classifier_dropout)
1516
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1517
+
1518
+ self.init_weights()
1519
+
1520
+ # Model parallel
1521
+ self.model_parallel = False
1522
+ self.device_map = None
1523
+
1524
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1525
+ @add_code_sample_docstrings(
1526
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1527
+ checkpoint="microsoft/DialogRPT-updown",
1528
+ output_type=TokenClassifierOutput,
1529
+ config_class=_CONFIG_FOR_DOC,
1530
+ )
1531
+ def forward(
1532
+ self,
1533
+ input_ids=None,
1534
+ past_key_values=None,
1535
+ attention_mask=None,
1536
+ token_type_ids=None,
1537
+ position_ids=None,
1538
+ head_mask=None,
1539
+ inputs_embeds=None,
1540
+ labels=None,
1541
+ use_cache=None,
1542
+ output_attentions=None,
1543
+ output_hidden_states=None,
1544
+ return_dict=None,
1545
+ ):
1546
+ r"""
1547
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1548
+ Labels for computing the token classification loss. Indices should be in :obj:`[0, ...,
1549
+ config.num_labels - 1]`.
1551
+ """
1552
+ return_dict = (
1553
+ return_dict if return_dict is not None else self.config.use_return_dict
1554
+ )
1555
+
1556
+ transformer_outputs = self.transformer(
1557
+ input_ids,
1558
+ past_key_values=past_key_values,
1559
+ attention_mask=attention_mask,
1560
+ token_type_ids=token_type_ids,
1561
+ position_ids=position_ids,
1562
+ head_mask=head_mask,
1563
+ inputs_embeds=inputs_embeds,
1564
+ use_cache=use_cache,
1565
+ output_attentions=output_attentions,
1566
+ output_hidden_states=output_hidden_states,
1567
+ return_dict=return_dict,
1568
+ )
1569
+
1570
+ hidden_states = transformer_outputs[0]
1571
+ hidden_states = self.dropout(hidden_states)
1572
+ logits = self.classifier(hidden_states)
1573
+
1574
+ loss = None
1575
+ if labels is not None:
1576
+ loss_fct = CrossEntropyLoss()
1577
+ # Only keep active parts of the loss
1578
+ if attention_mask is not None:
1579
+ active_loss = attention_mask.view(-1) == 1
1580
+ active_logits = logits.view(-1, self.num_labels)
1581
+ active_labels = torch.where(
1582
+ active_loss,
1583
+ labels.view(-1),
1584
+ torch.tensor(loss_fct.ignore_index).type_as(labels),
1585
+ )
1586
+ loss = loss_fct(active_logits, active_labels)
1587
+ else:
1588
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1589
+
1590
+ if not return_dict:
1591
+ output = (logits,) + transformer_outputs[2:]
1592
+ return ((loss,) + output) if loss is not None else output
1593
+
1594
+ return TokenClassifierOutput(
1595
+ loss=loss,
1596
+ logits=logits,
1597
+ hidden_states=transformer_outputs.hidden_states,
1598
+ attentions=transformer_outputs.attentions,
1599
+ )
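Since this file re-implements GPT2LMHeadModel (with the optional Grover-style embedding norm), a hedged usage sketch may help; it assumes the public aubmindlab/aragpt2-base checkpoint and that the module is importable as backend.modeling_gpt2:

from transformers import AutoTokenizer
from backend.modeling_gpt2 import GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
model = GPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-base")

input_ids = tokenizer("يحكى أن", return_tensors="pt").input_ids
output = model.generate(input_ids, max_length=40, do_sample=True, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))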
backend/preprocess.py ADDED
@@ -0,0 +1,736 @@
1
+ import html
2
+ import logging
3
+ import re
4
+ from typing import List
5
+ from farasa.segmenter import FarasaSegmenter
6
+ import emoji
7
+
8
+ import pyarabic.araby as araby
9
+
10
+ ACCEPTED_MODELS = [
11
+ "bert-base-arabertv01",
12
+ "bert-base-arabert",
13
+ "bert-base-arabertv02",
14
+ "bert-base-arabertv2",
15
+ "bert-large-arabertv02",
16
+ "bert-large-arabertv2",
17
+ "araelectra-base",
18
+ "araelectra-base-discriminator",
19
+ "araelectra-base-generator",
20
+ "araelectra-base-artydiqa",
21
+ "aragpt2-base",
22
+ "aragpt2-medium",
23
+ "aragpt2-large",
24
+ "aragpt2-mega",
25
+ ]
26
+
27
+ SEGMENTED_MODELS = [
28
+ "bert-base-arabert",
29
+ "bert-base-arabertv2",
30
+ "bert-large-arabertv2",
31
+ ]
32
+
33
+ SECOND_GEN_MODELS = [
34
+ "bert-base-arabertv02",
35
+ "bert-base-arabertv2",
36
+ "bert-large-arabertv02",
37
+ "bert-large-arabertv2",
38
+ "araelectra-base",
39
+ "araelectra-base-discriminator",
40
+ "araelectra-base-generator",
41
+ "araelectra-base-artydiqa",
42
+ "aragpt2-base",
43
+ "aragpt2-medium",
44
+ "aragpt2-large",
45
+ "aragpt2-mega",
46
+ ]
47
+
48
+ farasa_segmenter = FarasaSegmenter(interactive=True)
49
+
50
+
51
+ class ArabertPreprocessor:
52
+ """
53
+ A Preprocessor class that cleans and preprocesses text for all models in the AraBERT repo.
54
+ It can also un-preprocess the output of generated text.
55
+
56
+ Args:
57
+
58
+ model_name (:obj:`str`): model name from the HuggingFace Models page without
59
+ the aubmindlab tag. Will default to a base Arabic preprocessor if model name was not found.
60
+ Current accepted models are:
61
+
62
+ - "bert-base-arabertv01": No farasa segmentation.
63
+ - "bert-base-arabert": with farasa segmentation.
64
+ - "bert-base-arabertv02": No farasa segmentation.
65
+ - "bert-base-arabertv2": with farasa segmentation.
66
+ - "bert-large-arabertv02": No farasa segmentation.
67
+ - "bert-large-arabertv2": with farasa segmentation.
68
+ - "araelectra-base": No farasa segmentation.
69
+ - "araelectra-base-discriminator": No farasa segmentation.
70
+ - "araelectra-base-generator": No farasa segmentation.
71
+ - "aragpt2-base": No farasa segmentation.
72
+ - "aragpt2-medium": No farasa segmentation.
73
+ - "aragpt2-large": No farasa segmentation.
74
+ - "aragpt2-mega": No farasa segmentation.
75
+
76
+
77
+ keep_emojis(:obj:`bool`, `optional`, defaults to :obj:`False`): don't remove emojis while preprocessing.
78
+
79
+ remove_html_markup(:obj: `bool`, `optional`, defaults to :obj:`True`): Whether to remove html artifacts,
80
+ should be set to False when preprocessing TyDi QA.
81
+
82
+ replace_urls_emails_mentions(:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to replace urls, emails,
83
+ and mentions with special tokens.
84
+
85
+ strip_tashkeel(:obj:`bool`, `optional`, defaults to :obj:`True`): remove diacritics (FATHATAN, DAMMATAN, KASRATAN, FATHA, DAMMA,
86
+ KASRA, SUKUN, SHADDA).
87
+
88
+ strip_tatweel(:obj:`bool`, `optional`, defaults to :obj:`True`): remove tatweel '\\u0640'.
89
+
90
+ insert_white_spaces(:obj:`bool`, `optional`, defaults to :obj:`True`): insert whitespace before and after all non Arabic digits
91
+ or English digits or Arabic and English Alphabet or the 2 brackets, then inserts whitespace
92
+ between words and numbers or numbers and words.
93
+
94
+ remove_non_digit_repetition(:obj:`bool`, `optional`, defaults to :obj:`True`): replace repetitions of more than 2 of the same non-digit character with
95
+ 2 of this character.
96
+
97
+ replace_slash_with_dash(:obj:`bool`, `optional`, defaults to :obj:`None`): Will be automatically set to True in AraBERTv02,
98
+ AraELECTRA and AraGPT2.
99
+ Set to False to force disable, and True to force enable. Replaces the "/" with "-",
100
+ since "/" is missing from the AraBERTv2, AraELECTRA and AraGPT2 vocabularies.
101
+
102
+ map_hindi_numbers_to_arabic(:obj:`bool`, `optional`, defaults to :obj:`None`): Will be automatically set to True in
103
+ AraBERTv02, AraELECTRA and AraGPT2. Set to False to force disable, and True to force enable.
104
+ Replaces hindi numbers with the corresponding Arabic one. ex: "١٩٩٥" --> "1995".
105
+ This behavior is present by default in AraBERTv1 and v2 (with pre-segmentation),
106
+ and fixes an issue caused by a bug when inserting white spaces.
107
+
108
+ apply_farasa_segmentation(:obj:`bool`, `optional`, defaults to :obj:`None`): Will be automatically set to True in
109
+ AraBERTv2, and AraBERTv1. Set to False to force disable, and True to force enable.
110
+
111
+
112
+
113
+ Returns:
114
+
115
+ ArabertPreprocessor: A preprocessor instance
116
+
117
+ Example:
118
+
119
+ from preprocess import ArabertPreprocessor
120
+
121
+ arabert_prep = ArabertPreprocessor("aubmindlab/bert-base-arabertv2")
122
+
123
+ arabert_prep.preprocess("SOME ARABIC TEXT")
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ model_name: str,
129
+ keep_emojis: bool = False,
130
+ remove_html_markup: bool = True,
131
+ replace_urls_emails_mentions: bool = True,
132
+ strip_tashkeel: bool = True,
133
+ strip_tatweel: bool = True,
134
+ insert_white_spaces: bool = True,
135
+ remove_non_digit_repetition: bool = True,
136
+ replace_slash_with_dash: bool = None,
137
+ map_hindi_numbers_to_arabic: bool = None,
138
+ apply_farasa_segmentation: bool = None,
139
+ ):
140
+
141
+ model_name = model_name.replace("aubmindlab/", "").replace("wissamantoun/", "")
142
+
143
+ if model_name not in ACCEPTED_MODELS:
144
+ logging.warning(
145
+ """Model provided is not in the accepted model list. Preprocessor will default to a base Arabic preprocessor"""
146
+ )
147
+ self.model_name = "bert-base-arabertv02"
148
+ else:
149
+ self.model_name = model_name
150
+
151
+ if apply_farasa_segmentation is None:
152
+ if self.model_name in SEGMENTED_MODELS:
153
+ self.apply_farasa_segmentation = True
154
+ else:
155
+ self.apply_farasa_segmentation = False
156
+ else:
157
+ if (
158
+ apply_farasa_segmentation == False
159
+ and self.model_name in SEGMENTED_MODELS
160
+ ):
161
+ logging.warning(
162
+ "The selected model_name requires Farasa pre-segmentation, but apply_farasa_segmentation was set to False!"
163
+ )
164
+
165
+ self.apply_farasa_segmentation = apply_farasa_segmentation
166
+
167
+ self.keep_emojis = keep_emojis
168
+ self.remove_html_markup = remove_html_markup
169
+ self.replace_urls_emails_mentions = replace_urls_emails_mentions
170
+ self.strip_tashkeel = strip_tashkeel
171
+ self.strip_tatweel = strip_tatweel
172
+ self.insert_white_spaces = insert_white_spaces
173
+ self.remove_non_digit_repetition = remove_non_digit_repetition
174
+
175
+ if replace_slash_with_dash is None:
176
+ if self.model_name in SECOND_GEN_MODELS:
177
+ self.replace_slash_with_dash = True
178
+ else:
179
+ self.replace_slash_with_dash = False
180
+ else:
181
+ self.replace_slash_with_dash = replace_slash_with_dash
182
+
183
+ if map_hindi_numbers_to_arabic is None:
184
+ if self.model_name in SECOND_GEN_MODELS:
185
+ self.map_hindi_numbers_to_arabic = True
186
+ else:
187
+ self.map_hindi_numbers_to_arabic = False
188
+ else:
189
+ self.map_hindi_numbers_to_arabic = map_hindi_numbers_to_arabic
190
+
191
+ def preprocess(self, text: str) -> str:
192
+ """
193
+ Preprocess takes an input text line and applies the same preprocessing used in AraBERT
194
+ pretraining, or according to settings
195
+
196
+ Args:
197
+
198
+ text (:obj:`str`): input text string
199
+
200
+ Returns:
201
+
202
+ string: A preprocessed string depending on which model was selected
203
+ """
204
+ if (
205
+ self.model_name == "bert-base-arabert"
206
+ or self.model_name == "bert-base-arabertv01"
207
+ ):
208
+ return self._preprocess_v1(
209
+ text,
210
+ do_farasa_tokenization=self.apply_farasa_segmentation,
211
+ )
212
+
213
+ if self.model_name in SECOND_GEN_MODELS:
214
+ return self._preprocess_v2(text)
215
+
216
+ return self._preprocess_v3(text)
217
+
218
+ def unpreprocess(self, text: str, desegment: bool = True) -> str:
219
+ """Re-formats the text to a classic format where punctuation, brackets, and parentheses are not separated by whitespaces.
220
+ The objective is to make the generated text of any model appear natural and not preprocessed.
221
+
222
+ Args:
223
+ text (:obj:`str`): input text to be un-preprocessed
224
+ desegment (:obj:`bool`, optional): whether or not to remove Farasa pre-segmentation first. Defaults to True.
225
+
226
+ Returns:
227
+ str: The unpreprocessed (and possibly Farasa-desegmented) text.
228
+ """
229
+
230
+ if self.apply_farasa_segmentation and desegment:
231
+ text = self.desegment(text)
232
+
233
+ # removes the spaces around quotation marks ex: i " ate " an apple --> i "ate" an apple
234
+ # https://stackoverflow.com/a/53436792/5381220
235
+ text = re.sub(white_spaced_double_quotation_regex, '"' + r"\1" + '"', text)
236
+ text = re.sub(white_spaced_single_quotation_regex, "'" + r"\1" + "'", text)
237
+ text = re.sub(white_spaced_back_quotation_regex, "\`" + r"\1" + "\`", text)
238
+ text = re.sub(white_spaced_em_dash, "—" + r"\1" + "—", text)
239
+
240
+ # during generation, sometimes the models don't put a space after the dot, this handles it
241
+ text = text.replace(".", " . ")
242
+ text = " ".join(text.split())
243
+
244
+ # handle decimals
245
+ text = re.sub(r"(\d+) \. (\d+)", r"\1.\2", text)
246
+ text = re.sub(r"(\d+) \, (\d+)", r"\1,\2", text)
247
+
248
+ text = re.sub(left_and_right_spaced_chars, r"\1", text)
249
+ text = re.sub(left_spaced_chars, r"\1", text)
250
+ text = re.sub(right_spaced_chars, r"\1", text)
251
+
252
+ return text
253
+
254
+ def desegment(self, text: str) -> str:
255
+ """
256
+ Use this function if sentence tokenization was done using
257
+ `from arabert.preprocess_arabert import preprocess` with Farasa enabled.
258
+ AraBERT segmentation using Farasa adds a space after the '+' for prefixes,
259
+ and a space before the '+' for suffixes.
260
+
261
+ Example:
262
+ >>> desegment('ال+ دراس +ات')
263
+ الدراسات
264
+ """
265
+ text = text.replace("+ ", "+")
266
+ text = text.replace(" +", "+")
267
+ text = " ".join([self._desegmentword(word) for word in text.split(" ")])
268
+ return text
269
+
270
+ def _desegmentword(self, orig_word: str) -> str:
271
+ """
272
+ Word segmentor that takes a Farasa Segmented Word and removes the '+' signs
273
+
274
+ Example:
275
+ >>> _desegmentword("ال+يومي+ة")
276
+ اليومية
277
+ """
278
+ word = orig_word.replace("ل+ال+", "لل")
279
+ if "ال+ال" not in orig_word:
280
+ word = word.replace("ل+ال", "لل")
281
+ word = word.replace("+", "")
282
+ word = word.replace("للل", "لل")
283
+ return word
284
+
285
+ def _preprocess_v3(self, text: str) -> str:
286
+ text = str(text)
287
+ text = html.unescape(text)
288
+ if self.strip_tashkeel:
289
+ text = araby.strip_tashkeel(text)
290
+ if self.strip_tatweel:
291
+ text = araby.strip_tatweel(text)
292
+
293
+ if self.replace_urls_emails_mentions:
294
+ # replace all possible URLs
295
+ for reg in url_regexes:
296
+ text = re.sub(reg, " [رابط] ", text)
297
+ # replace emails with [بريد]
298
+ for reg in email_regexes:
299
+ text = re.sub(reg, " [بريد] ", text)
300
+ # replace mentions with [مستخدم]
301
+ text = re.sub(user_mention_regex, " [مستخدم] ", text)
302
+
303
+ if self.remove_html_markup:
304
+ # remove html line breaks
305
+ text = re.sub("<br />", " ", text)
306
+ # remove html markup
307
+ text = re.sub("</?[^>]+>", " ", text)
308
+
309
+ if self.map_hindi_numbers_to_arabic:
310
+ text = text.translate(hindi_to_arabic_map)
311
+
312
+ # remove repeated characters >2
313
+ if self.remove_non_digit_repetition:
314
+ text = self._remove_non_digit_repetition(text)
315
+
316
+ # insert whitespace before and after all non Arabic digits or English Digits and Alphabet and the 2 brackets
317
+ if self.insert_white_spaces:
318
+ text = re.sub(
319
+ "([^0-9\u0621-\u063A\u0641-\u064A\u0660-\u0669a-zA-Z ])",
320
+ r" \1 ",
321
+ text,
322
+ )
323
+
324
+ # re-fix brackets
325
+ text = text.replace("[ رابط ]", "[رابط]")
326
+ text = text.replace("[ بريد ]", "[بريد]")
327
+ text = text.replace("[ مستخدم ]", "[مستخدم]")
328
+
329
+ # insert whitespace between words and numbers or numbers and words
330
+ text = re.sub(
331
+ "(\d+)([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)",
332
+ r" \1 \2 ",
333
+ text,
334
+ )
335
+ text = re.sub(
336
+ "([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)(\d+)",
337
+ r" \1 \2 ",
338
+ text,
339
+ )
340
+
341
+ # remove unwanted characters
342
+ if self.keep_emojis:
343
+ emoji_regex = "".join(list(emoji.UNICODE_EMOJI["en"].keys()))
344
+ rejected_chars_regex2 = "[^%s%s]" % (chars_regexv2, emoji_regex)
345
+ text = re.sub(rejected_chars_regex2, " ", text)
346
+ else:
347
+ text = re.sub(rejected_chars_regexv2, " ", text)
348
+
349
+ # remove extra spaces
350
+ text = " ".join(text.replace("\uFE0F", "").split())
351
+
352
+ if self.apply_farasa_segmentation:
353
+ if self.keep_emojis:
354
+ new_text = []
355
+ for word in text.split():
356
+ if word in list(emoji.UNICODE_EMOJI["en"].keys()):
357
+ new_text.append(word)
358
+ else:
359
+ new_text.append(farasa_segmenter.segment(word))
360
+ text = " ".join(new_text)
361
+ else:
362
+ text = farasa_segmenter.segment(text)
363
+ return self._farasa_segment(text)
364
+
365
+ # All the other models don't require Farasa segmentation
366
+ return text
367
+
368
+ def _preprocess_v2(self, text: str) -> str:
369
+ text = str(text)
370
+ text = html.unescape(text)
371
+ if self.strip_tashkeel:
372
+ text = araby.strip_tashkeel(text)
373
+ if self.strip_tatweel:
374
+ text = araby.strip_tatweel(text)
375
+
376
+ if self.replace_urls_emails_mentions:
377
+ # replace all possible URLs
378
+ for reg in url_regexes:
379
+ text = re.sub(reg, " [رابط] ", text)
380
+ # replace emails with [بريد]
381
+ for reg in email_regexes:
382
+ text = re.sub(reg, " [بريد] ", text)
383
+ # replace mentions with [مستخدم]
384
+ text = re.sub(user_mention_regex, " [مستخدم] ", text)
385
+
386
+ if self.remove_html_markup:
387
+ # remove html line breaks
388
+ text = re.sub("<br />", " ", text)
389
+ # remove html markup
390
+ text = re.sub("</?[^>]+>", " ", text)
391
+
392
+ if self.map_hindi_numbers_to_arabic:
393
+ text = text.translate(hindi_to_arabic_map)
394
+
395
+ # remove repeated characters >2
396
+ if self.remove_non_digit_repetition:
397
+ text = self._remove_non_digit_repetition(text)
398
+
399
+ # insert whitespace before and after all non Arabic digits or English Digits and Alphabet and the 2 brackets
400
+ if self.insert_white_spaces:
401
+ text = re.sub(
402
+ "([^0-9\u0621-\u063A\u0641-\u064A\u0660-\u0669a-zA-Z\[\]])",
403
+ r" \1 ",
404
+ text,
405
+ )
406
+
407
+ # insert whitespace between words and numbers or numbers and words
408
+ text = re.sub(
409
+ "(\d+)([\u0621-\u063A\u0641-\u064A\u0660-\u066C]+)", r" \1 \2 ", text
410
+ )
411
+ text = re.sub(
412
+ "([\u0621-\u063A\u0641-\u064A\u0660-\u066C]+)(\d+)", r" \1 \2 ", text
413
+ )
414
+
415
+ if self.replace_slash_with_dash:
416
+ text = text.replace("/", "-")
417
+
418
+ # remove unwanted characters
419
+ if self.keep_emojis:
420
+ emoji_regex = "".join(list(emoji.UNICODE_EMOJI["en"].keys()))
421
+ rejected_chars_regex2 = "[^%s%s]" % (chars_regex, emoji_regex)
422
+ text = re.sub(rejected_chars_regex2, " ", text)
423
+ else:
424
+ text = re.sub(rejected_chars_regex, " ", text)
425
+
426
+ # remove extra spaces
427
+ text = " ".join(text.replace("\uFE0F", "").split())
428
+
429
+ if (
430
+ self.model_name == "bert-base-arabertv2"
431
+ or self.model_name == "bert-large-arabertv2"
432
+ ):
433
+ if self.keep_emojis:
434
+ new_text = []
435
+ for word in text.split():
436
+ if word in list(emoji.UNICODE_EMOJI["en"].keys()):
437
+ new_text.append(word)
438
+ else:
439
+ new_text.append(farasa_segmenter.segment(word))
440
+ text = " ".join(new_text)
441
+ else:
442
+ text = farasa_segmenter.segment(text)
443
+ return self._farasa_segment(text)
444
+
445
+ # All the other models don't require Farasa segmentation
446
+ return text
447
+
448
+ def _preprocess_v1(self, text: str, do_farasa_tokenization: bool) -> str:
449
+ """
450
+ AraBERTv1 preprocessing Function
451
+ """
452
+ text = str(text)
453
+ if self.strip_tashkeel:
454
+ text = araby.strip_tashkeel(text)
455
+
456
+ text = re.sub(r"\d+\/[ء-ي]+\/\d+\]", "", text)
457
+ text = re.sub("ـ", "", text)
458
+ text = re.sub("[«»]", ' " ', text)
459
+
460
+ if self.replace_urls_emails_mentions:
461
+ # replace the [رابط] token with space if you want to clean links
462
+ text = re.sub(regex_url_step1, "[رابط]", text)
463
+ text = re.sub(regex_url_step2, "[رابط]", text)
464
+ text = re.sub(regex_url, "[رابط]", text)
465
+ text = re.sub(regex_email, "[بريد]", text)
466
+ text = re.sub(regex_mention, "[مستخدم]", text)
467
+ text = re.sub("…", r"\.", text).strip()
468
+ text = self._remove_redundant_punct(text)
469
+
470
+ if self.replace_urls_emails_mentions:
471
+ text = re.sub(r"\[ رابط \]|\[ رابط\]|\[رابط \]", " [رابط] ", text)
472
+ text = re.sub(r"\[ بريد \]|\[ بريد\]|\[بريد \]", " [بريد] ", text)
473
+ text = re.sub(r"\[ مستخدم \]|\[ مستخدم\]|\[مستخدم \]", " [مستخدم] ", text)
474
+
475
+ if self.remove_non_digit_repetition:
476
+ text = self._remove_non_digit_repetition(text)
477
+
478
+ if self.insert_white_spaces:
479
+ text = re.sub(
480
+ "([^0-9\u0621-\u063A\u0641-\u0669\u0671-\u0673a-zA-Z\[\]])",
481
+ r" \1 ",
482
+ text,
483
+ )
484
+ if do_farasa_tokenization:
485
+ text = self._tokenize_arabic_words_farasa(text)
486
+
487
+ text = " ".join(text.split())
488
+
489
+ return text
490
+
491
+ def _farasa_segment(self, text: str) -> str:
492
+ line_farasa = text.split()
493
+ segmented_line = []
494
+ for index, word in enumerate(line_farasa):
495
+ if word in ["[", "]"]:
496
+ continue
497
+ if word in ["رابط", "بريد", "مستخدم"] and line_farasa[index - 1] in [
498
+ "[",
499
+ "]",
500
+ ]:
501
+ segmented_line.append("[" + word + "]")
502
+ continue
503
+ if "+" not in word:
504
+ segmented_line.append(word)
505
+ continue
506
+ segmented_word = self._split_farasa_output(word)
507
+ segmented_line.extend(segmented_word)
508
+
509
+ return " ".join(segmented_line)
510
+
511
+ def _split_farasa_output(self, word: str) -> str:
512
+ segmented_word = []
513
+ temp_token = ""
514
+ for i, c in enumerate(word):
515
+ if c == "+":
516
+ # if the token is KAF, it could be a suffix or prefix
517
+ if temp_token == "ك":
518
+ # if we are at the second token, then KAF is surely a prefix
519
+ if i == 1:
520
+ segmented_word.append(temp_token + "+")
521
+ temp_token = ""
522
+ # If the KAF token is between 2 tokens
523
+ elif word[i - 2] == "+":
524
+ # if the previous token is prefix, then this KAF must be a prefix
525
+ if segmented_word[-1][-1] == "+":
526
+ segmented_word.append(temp_token + "+")
527
+ temp_token = ""
528
+ # else it is a suffix, this KAF could not be a second suffix
529
+ else:
530
+ segmented_word.append("+" + temp_token)
531
+ temp_token = ""
532
+ # if Kaf is at the end, this is handled with the statement after the loop
533
+ elif temp_token in prefix_list:
534
+ segmented_word.append(temp_token + "+")
535
+ temp_token = ""
536
+ elif temp_token in suffix_list:
537
+ segmented_word.append("+" + temp_token)
538
+ temp_token = ""
539
+ else:
540
+ segmented_word.append(temp_token)
541
+ temp_token = ""
542
+ continue
543
+ temp_token += c
544
+ if temp_token != "":
545
+ if temp_token in suffix_list:
546
+ segmented_word.append("+" + temp_token)
547
+ else:
548
+ segmented_word.append(temp_token)
549
+ return segmented_word
550
+
551
+ def _tokenize_arabic_words_farasa(self, line_input: str) -> str:
552
+
553
+ if self.keep_emojis:
554
+ # segment word by word so that emojis are kept intact
555
+ line_farasa = []
556
+ for word in line_input.split():
557
+ if word in list(emoji.UNICODE_EMOJI["en"].keys()):
558
+ line_farasa.append(word)
559
+ else:
560
+ line_farasa.append(farasa_segmenter.segment(word))
561
+ else:
562
+ line_farasa = farasa_segmenter.segment(line_input).split()
563
+
564
+ segmented_line = []
565
+ for index, word in enumerate(line_farasa):
566
+ if word in ["[", "]"]:
567
+ continue
568
+ if word in ["رابط", "بريد", "مستخدم"] and line_farasa[index - 1] in [
569
+ "[",
570
+ "]",
571
+ ]:
572
+ segmented_line.append("[" + word + "]")
573
+ continue
574
+ segmented_word = []
575
+ for token in word.split("+"):
576
+ if token in prefix_list:
577
+ segmented_word.append(token + "+")
578
+ elif token in suffix_list:
579
+ segmented_word.append("+" + token)
580
+ else:
581
+ segmented_word.append(token)
582
+ segmented_line.extend(segmented_word)
583
+ return " ".join(segmented_line)
584
+
585
+ def _remove_non_digit_repetition(self, text: str) -> str:
586
+ """
587
+ :param text: the input text to remove elongation from
588
+ :return: the de-elongated text
589
+ """
590
+ # loop over the number of times the regex matched the text
591
+ # OLD
592
+ # for index_ in range(len(re.findall(regex_tatweel, text))):
593
+ # elongation = re.search(regex_tatweel, text)
594
+ # if elongation:
595
+ # elongation_pattern = elongation.group()
596
+ # elongation_replacement = elongation_pattern[0]
597
+ # elongation_pattern = re.escape(elongation_pattern)
598
+ # text = re.sub(
599
+ # elongation_pattern, elongation_replacement, text, flags=re.MULTILINE
600
+ # )
601
+ # else:
602
+ # break
603
+
604
+ # New
605
+ text = multiple_char_pattern.sub(r"\1\1", text)
606
+ return text
607
+
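`multiple_char_pattern` collapses any non-digit character repeated three or more times down to two, while digits (years, prices) pass through untouched. A quick sketch:

import re

multiple_char_pattern = re.compile(r"(\D)\1{2,}", re.DOTALL)
print(multiple_char_pattern.sub(r"\1\1", "cooooool!!!! 2000"))  # -> "cool!! 2000"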
608
+ def _remove_redundant_punct(self, text: str) -> str:
609
+ text_ = text
610
+ result = re.search(redundant_punct_pattern, text)
611
+ dif = 0
612
+ while result:
613
+ sub = result.group()
614
+ sub = sorted(set(sub), key=sub.index)
615
+ sub = " " + "".join(list(sub)) + " "
616
+ text = "".join(
617
+ (text[: result.span()[0] + dif], sub, text[result.span()[1] + dif :])
618
+ )
619
+ text_ = "".join(
620
+ (text_[: result.span()[0]], text_[result.span()[1] :])
621
+ ).strip()
622
+ dif = abs(len(text) - len(text_))
623
+ result = re.search(redundant_punct_pattern, text_)
624
+ text = re.sub(r"\s+", " ", text)
625
+ return text.strip()
626
+
627
+
628
+ prefix_list = [
629
+ "ال",
630
+ "و",
631
+ "ف",
632
+ "ب",
633
+ "ك",
634
+ "ل",
635
+ "لل",
636
+ "\u0627\u0644",
637
+ "\u0648",
638
+ "\u0641",
639
+ "\u0628",
640
+ "\u0643",
641
+ "\u0644",
642
+ "\u0644\u0644",
643
+ "س",
644
+ ]
645
+ suffix_list = [
646
+ "ه",
647
+ "ها",
648
+ "ك",
649
+ "ي",
650
+ "هما",
651
+ "كما",
652
+ "نا",
653
+ "كم",
654
+ "هم",
655
+ "هن",
656
+ "كن",
657
+ "ا",
658
+ "ان",
659
+ "ين",
660
+ "ون",
661
+ "وا",
662
+ "ات",
663
+ "ت",
664
+ "ن",
665
+ "ة",
666
+ "\u0647",
667
+ "\u0647\u0627",
668
+ "\u0643",
669
+ "\u064a",
670
+ "\u0647\u0645\u0627",
671
+ "\u0643\u0645\u0627",
672
+ "\u0646\u0627",
673
+ "\u0643\u0645",
674
+ "\u0647\u0645",
675
+ "\u0647\u0646",
676
+ "\u0643\u0646",
677
+ "\u0627",
678
+ "\u0627\u0646",
679
+ "\u064a\u0646",
680
+ "\u0648\u0646",
681
+ "\u0648\u0627",
682
+ "\u0627\u062a",
683
+ "\u062a",
684
+ "\u0646",
685
+ "\u0629",
686
+ ]
687
+ other_tokens = ["[رابط]", "[مستخدم]", "[بريد]"]
688
+
689
+ # the never_split list is used with the transformers library
690
+ prefix_symbols = [x + "+" for x in prefix_list]
691
+ suffix_symblos = ["+" + x for x in suffix_list]
692
+ never_split_tokens = list(set(prefix_symbols + suffix_symblos + other_tokens))
693
+
694
+ url_regexes = [
695
+ r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)",
696
+ r"@(https?|ftp)://(-\.)?([^\s/?\.#-]+\.?)+(/[^\s]*)?$@iS",
697
+ r"http[s]?://[a-zA-Z0-9_\-./~\?=%&]+",
698
+ r"www[a-zA-Z0-9_\-?=%&/.~]+",
699
+ r"[a-zA-Z]+\.com",
700
+ r"(?=http)[^\s]+",
701
+ r"(?=www)[^\s]+",
702
+ r"://",
703
+ ]
704
+ user_mention_regex = r"@[\w\d]+"
705
+ email_regexes = [r"[\w-]+@([\w-]+\.)+[\w-]+", r"\S+@\S+"]
706
+ redundant_punct_pattern = (
707
+ r"([!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ【»؛\s+«–…‘]{2,})"
708
+ )
709
+
710
+ regex_tatweel = r"(\D)\1{2,}"
711
+ multiple_char_pattern = re.compile(r"(\D)\1{2,}", re.DOTALL)
712
+
713
+ rejected_chars_regex = r"[^0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘]"
714
+ rejected_chars_regexv2 = r"[^0-9\u0621-\u063A\u0641-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘/]"
715
+
716
+ regex_url_step1 = r"(?=http)[^\s]+"
717
+ regex_url_step2 = r"(?=www)[^\s]+"
718
+ regex_url = r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)"
719
+ regex_mention = r"@[\w\d]+"
720
+ regex_email = r"\S+@\S+"
721
+
722
+ chars_regex = r"0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘"
723
+ chars_regexv2 = r"0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘/"
724
+
725
+ white_spaced_double_quotation_regex = r'\"\s+([^"]+)\s+\"'
726
+ white_spaced_single_quotation_regex = r"\'\s+([^']+)\s+\'"
727
+ white_spaced_back_quotation_regex = r"\`\s+([^`]+)\s+\`"
728
+ white_spaced_em_dash = r"\—\s+([^—]+)\s+\—"
729
+
730
+ left_spaced_chars = r" ([\]!#\$%\),\.:;\?}٪’،؟”؛…»·])"
731
+ right_spaced_chars = r"([\[\(\{“«‘*\~]) "
732
+ left_and_right_spaced_chars = r" ([\+\-\<\=\>\@\\\^\_\|\–]) "
733
+
734
+ hindi_nums = "٠١٢٣٤٥٦٧٨٩"
735
+ arabic_nums = "0123456789"
736
+ hindi_to_arabic_map = str.maketrans(hindi_nums, arabic_nums)
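str.maketrans builds the one-pass digit table behind map_hindi_numbers_to_arabic, giving the "١٩٩٥" -> "1995" conversion described in the class docstring. A quick check:

hindi_to_arabic_map = str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")
print("تأسست عام ١٩٩٥".translate(hindi_to_arabic_map))  # -> "تأسست عام 1995"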
backend/processor.py ADDED
@@ -0,0 +1,183 @@
1
+ import streamlit as st
2
+ import awesome_streamlit as ast
3
+ from .preprocess import (
4
+ ArabertPreprocessor,
5
+ white_spaced_back_quotation_regex,
6
+ white_spaced_double_quotation_regex,
7
+ white_spaced_em_dash,
8
+ white_spaced_single_quotation_regex,
9
+ left_and_right_spaced_chars,
10
+ left_spaced_chars,
11
+ right_spaced_chars,
12
+ )
13
+ import re
14
+
15
+ MODELS_to_SELECT = [
16
+ "None",
17
+ "bert-base-arabertv01",
18
+ "bert-base-arabert",
19
+ "bert-base-arabertv02",
20
+ "bert-base-arabertv2",
21
+ "bert-large-arabertv02",
22
+ "bert-large-arabertv2",
23
+ "araelectra-base",
24
+ "araelectra-base-discriminator",
25
+ "araelectra-base-generator",
26
+ "araelectra-base-artydiqa",
27
+ "aragpt2-base",
28
+ "aragpt2-medium",
29
+ "aragpt2-large",
30
+ "aragpt2-mega",
31
+ ]
32
+
33
+
34
+ def unpreprocess(text: str) -> str:
35
+ """Re-formats the text to a classic format where punctuation, brackets, and parentheses are not separated by whitespaces.
36
+ The objective is to make the generated text of any model appear natural and not preprocessed.
37
+
38
+ Args:
39
+ text (:obj:`str`): input text to be un-preprocessed
40
+ desegment (:obj:`bool`, optional): [whether or not to remove farasa pre-segmentation before]..
41
+
42
+ Returns:
43
+ str: The unpreprocessed (and possibly Farasa-desegmented) text.
44
+ """
45
+
46
+ text = desegment(text)
47
+
48
+ # removes the spaces around quotation marks ex: i " ate " an apple --> i "ate" an apple
49
+ # https://stackoverflow.com/a/53436792/5381220
50
+ text = re.sub(white_spaced_double_quotation_regex, '"' + r"\1" + '"', text)
51
+ text = re.sub(white_spaced_single_quotation_regex, "'" + r"\1" + "'", text)
52
+ text = re.sub(white_spaced_back_quotation_regex, "\`" + r"\1" + "\`", text)
53
+ text = re.sub(white_spaced_em_dash, "—" + r"\1" + "—", text)
54
+
55
+ # during generation, sometimes the models don't put a space after the dot, this handles it
56
+ text = text.replace(".", " . ")
57
+ text = " ".join(text.split())
58
+
59
+ # handle decimals
60
+ text = re.sub(r"(\d+) \. (\d+)", r"\1.\2", text)
61
+ text = re.sub(r"(\d+) \, (\d+)", r"\1,\2", text)
62
+
63
+ text = re.sub(left_and_right_spaced_chars, r"\1", text)
64
+ text = re.sub(left_spaced_chars, r"\1", text)
65
+ text = re.sub(right_spaced_chars, r"\1", text)
66
+
67
+ return text
68
+
69
+
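The quotation regexes imported above capture the text between white-spaced quotes and re-emit it without the padding. A minimal sketch using the double-quote pattern (the same regex as in backend/preprocess.py):

import re

white_spaced_double_quotation_regex = r'\"\s+([^"]+)\s+\"'
print(re.sub(white_spaced_double_quotation_regex, '"' + r"\1" + '"', 'i " ate " an apple'))
# -> 'i "ate" an apple'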
70
+ def desegment(text: str) -> str:
71
+ """
72
+ Use this function if sentence tokenization was done using
73
+ `from arabert.preprocess_arabert import preprocess` with Farasa enabled.
74
+ AraBERT segmentation using Farasa adds a space after the '+' for prefixes,
75
+ and a space before the '+' for suffixes.
76
+
77
+ Example:
78
+ >>> desegment('ال+ دراس +ات')
79
+ الدراسات
80
+ """
81
+ text = text.replace("+ ", "+")
82
+ text = text.replace(" +", "+")
83
+ text = " ".join([_desegmentword(word) for word in text.split(" ")])
84
+ return text
85
+
86
+
87
+ def _desegmentword(orig_word: str) -> str:
88
+ """
89
+ Word segmentor that takes a Farasa Segmented Word and removes the '+' signs
90
+
91
+ Example:
92
+ >>> _desegmentword("ال+يومي+ة")
93
+ اليومية
94
+ """
95
+ word = orig_word.replace("ل+ال+", "لل")
96
+ if "ال+ال" not in orig_word:
97
+ word = word.replace("ل+ال", "لل")
98
+ word = word.replace("+", "")
99
+ word = word.replace("للل", "لل")
100
+ return word
101
+
102
+
103
+ def write():
104
+
105
+ st.markdown(
106
+ """
107
+ <h1 style="text-align:left;">Arabic Text Pre-Processor</h1>
108
+ """,
109
+ unsafe_allow_html=True,
110
+ )
111
+ st.markdown(
112
+ """
113
+ <style>
114
+ p, div, input, label {
115
+ text-align: right;
116
+ }
117
+ </style>
118
+ """,
119
+ unsafe_allow_html=True,
120
+ )
121
+ input_text = st.text_input(
122
+ "Text to Pre-Process",
123
+ value="ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري",
124
+ )
125
+
126
+ st.sidebar.title("Model Selector")
127
+ model_selector = st.sidebar.selectbox(
128
+ """Select None to enable further filters""", options=MODELS_to_SELECT, index=3
129
+ )
130
+ if model_selector == "None":
131
+ keep_emojis = st.sidebar.checkbox("Keep emojis", False)
132
+ remove_html_markup = st.sidebar.checkbox("Remove html markup", True)
133
+ strip_tashkeel = st.sidebar.checkbox("Strip tashkeel", True)
134
+ replace_urls_emails_mentions = st.sidebar.checkbox(
135
+ "Replace urls and emails", True
136
+ )
137
+ strip_tatweel = st.sidebar.checkbox("Strip tatweel", True)
138
+ insert_white_spaces = st.sidebar.checkbox("Insert white spaces", True)
139
+ remove_non_digit_repetition = st.sidebar.checkbox(
140
+ "Remove non-digit repetition", True
141
+ )
142
+ replace_slash_with_dash = st.sidebar.checkbox("Replace slash with dash", None)
143
+ map_hindi_numbers_to_arabic = st.sidebar.checkbox(
144
+ "Map hindi numbers to arabic", None
145
+ )
146
+ apply_farasa_segmentation = st.sidebar.checkbox(
147
+ "Apply farasa segmentation", None
148
+ )
149
+
150
+ run_preprocessor = st.button("Run Pre-Processor")
151
+
152
+ prep_text = None
153
+ if run_preprocessor:
154
+ if model_selector == "None":
155
+ arabert_preprocessor = ArabertPreprocessor(
156
+ model_selector,
157
+ keep_emojis,
158
+ remove_html_markup,
159
+ replace_urls_emails_mentions,
160
+ strip_tashkeel,
161
+ strip_tatweel,
162
+ insert_white_spaces,
163
+ remove_non_digit_repetition,
164
+ replace_slash_with_dash,
165
+ map_hindi_numbers_to_arabic,
166
+ apply_farasa_segmentation,
167
+ )
168
+ else:
169
+ arabert_preprocessor = ArabertPreprocessor(model_name=model_selector)
170
+ prep_text = arabert_preprocessor._preprocess_v3(input_text)
171
+ st.write(prep_text)
172
+
173
+ st.write("-----")
174
+ input_text_unprep = st.text_input(
175
+ "Text to Undo the Pre-Processing",
176
+ value=prep_text
177
+ if prep_text
178
+ else "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري",
179
+ )
180
+ run_unpreprocessor = st.button("Run Un-Pre-Processor")
181
+
182
+ if run_unpreprocessor:
183
+ st.write(unpreprocess(input_text_unprep))
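
`desegment` collapses Farasa's "+ " prefix/suffix markers and `unpreprocess` then restores natural punctuation spacing. A small round-trip sketch, assuming the module above is importable as `backend.processor` (the sample string is illustrative):

    from backend.processor import desegment, unpreprocess

    segmented = "و+ قال +وا : ال+ كتاب جميل"
    print(desegment(segmented))     # -> "وقالوا : الكتاب جميل"
    print(unpreprocess(segmented))  # also removes the space before ":"
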
backend/qa.py ADDED
@@ -0,0 +1,50 @@
+ import streamlit as st
+
+ from .qa_utils import annotate_answer
+ from .services import get_qa_answers
+
+
+ def write():
+     _, col1, _ = st.columns(3)
+
+     with col1:
+         st.image("images/is2alni_logo.png", width=200)
+         st.title("إسألني أي شيء")
+
+     st.markdown(
+         """
+ <style>
+ p, div, input, label {
+     text-align: right;
+ }
+ </style>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     st.sidebar.header("Info")
+     st.sidebar.image("images/AraELECTRA.png", width=150)
+     st.sidebar.write("Powered by [AraELECTRA](https://github.com/aub-mind/arabert)")
+
+     st.sidebar.write("\n")
+     n_answers = st.sidebar.slider(
+         "Max. number of answers", min_value=1, max_value=10, value=2, step=1
+     )
+
+     question = st.text_input("", value="من هو جو بايدن؟")
+     if "؟" not in question:
+         question += "؟"
+
+     run_query = st.button("أجب")
+     if run_query:
+         # https://discuss.streamlit.io/t/showing-a-gif-while-st-spinner-runs/5084
+         with st.spinner("... جاري البحث "):
+             results_dict = get_qa_answers(question)
+
+         if len(results_dict) > 0:
+             st.write("## :الأجابات هي")
+             for result in results_dict["results"][:n_answers]:
+                 annotate_answer(result)
+                 f"[**المصدر**](<{result['link']}>)"
+         else:
+             st.write("## 😞 ليس لدي جواب")
backend/qa_utils.py ADDED
@@ -0,0 +1,163 @@
+ import streamlit.components.v1
+
+ from htbuilder import HtmlElement, div, span, styles
+ from htbuilder.units import px, rem, em
+
+
+ def annotation(body, label="", background="#ddd", color="#333", **style):
+     """Build an HtmlElement span object with the given body and annotation label.
+
+     The end result will look something like this:
+
+         [body | label]
+
+     Parameters
+     ----------
+     body : string
+         The string to put in the "body" part of the annotation.
+     label : string
+         The string to put in the "label" part of the annotation.
+     background : string
+         The color to use for the background "chip" containing this annotation.
+     color : string
+         The color to use for the body and label text.
+     **style : dict
+         Any CSS you want to use to customize the containing "chip".
+
+     Examples
+     --------
+
+     Produce a simple annotation with default colors:
+
+     >>> annotation("apple", "fruit")
+
+     Produce an annotation with custom colors:
+
+     >>> annotation("apple", "fruit", background="#FF0", color="black")
+
+     Produce an annotation with crazy CSS:
+
+     >>> annotation("apple", "fruit", background="#FF0", border="1px dashed red")
+
+     """
+
+     if "font_family" not in style:
+         style["font_family"] = "sans-serif"
+
+     return span(
+         style=styles(
+             background=background,
+             border_radius=rem(0.33),
+             color=color,
+             padding=(rem(0.17), rem(0.67)),
+             display="inline-flex",
+             justify_content="center",
+             align_items="center",
+             **style,
+         )
+     )(
+         body,
+         span(
+             style=styles(
+                 color=color,
+                 font_size=em(0.67),
+                 opacity=0.5,
+                 padding_left=rem(0.5),
+                 text_transform="uppercase",
+                 margin_bottom=px(-2),
+             )
+         )(label),
+     )
+
+
+ def annotated_text(*args, **kwargs):
+     """Writes text with annotations into your Streamlit app.
+
+     Parameters
+     ----------
+     *args : str, tuple or htbuilder.HtmlElement
+         Arguments can be:
+         - strings, to draw the string as-is on the screen.
+         - tuples of the form (main_text, annotation_text, background, color) where
+           background and foreground colors are optional and should be a CSS-valid string such as
+           "#aabbcc" or "rgb(10, 20, 30)"
+         - HtmlElement objects in case you want to customize the annotations further. In particular,
+           you can import the `annotation()` function from this module to easily produce annotations
+           whose CSS you can customize via keyword arguments.
+
+     Examples
+     --------
+
+     >>> annotated_text(
+     ...     "This ",
+     ...     ("is", "verb", "#8ef"),
+     ...     " some ",
+     ...     ("annotated", "adj", "#faa"),
+     ...     ("text", "noun", "#afa"),
+     ...     " for those of ",
+     ...     ("you", "pronoun", "#fea"),
+     ...     " who ",
+     ...     ("like", "verb", "#8ef"),
+     ...     " this sort of ",
+     ...     ("thing", "noun", "#afa"),
+     ... )
+
+     >>> annotated_text(
+     ...     "Hello ",
+     ...     annotation("world!", "noun", color="#8ef", border="1px dashed red"),
+     ... )
+
+     """
+     out = div(
+         style=styles(
+             font_family="sans-serif",
+             line_height="1.45",
+             font_size=px(16),
+             text_align="right",
+         )
+     )
+
+     for arg in args:
+         if isinstance(arg, str):
+             out(arg)
+
+         elif isinstance(arg, HtmlElement):
+             out(arg)
+
+         elif isinstance(arg, tuple):
+             out(annotation(*arg))
+
+         else:
+             raise ValueError(f"Unsupported argument type: {type(arg)}")
+
+     streamlit.components.v1.html(str(out), **kwargs)
+
+
+ def shorten_text(text, n, reverse=False):
+     """Truncate `text` at a word boundary so it fits within `n` characters.
+     With `reverse=True`, the text is truncated from its start instead of its end."""
+     if text.isspace() or text == "":
+         return text
+     if reverse:
+         text = text[::-1]
+     words = iter(text.split())
+     lines, current = [], next(words)
+     for word in words:
+         if len(current) + 1 + len(word) > n:
+             break
+         else:
+             current += " " + word
+     lines.append(current)
+     if reverse:
+         return lines[0][::-1]
+     return lines[0]
+
+
+ def annotate_answer(result):
+     annotated_text(
+         shorten_text(
+             result["original"][: result["new_start"]],
+             500,
+             reverse=True,
+         ),
+         (result["new_answer"], "جواب", "#8ef"),
+         shorten_text(result["original"][result["new_end"] :], 500) + " ...... إلخ",
+     )
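
`shorten_text` keeps whole words until the character budget is exceeded; with `reverse=True` it keeps the tail of the string instead, which is how `annotate_answer` shows context on both sides of an answer span. An illustrative sketch of that behavior (sample strings are made up):

    print(shorten_text("the quick brown fox jumps", 9))                # -> "the quick"
    print(shorten_text("the quick brown fox jumps", 9, reverse=True))  # -> "fox jumps"
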
backend/sa.py ADDED
@@ -0,0 +1,76 @@
+ import streamlit as st
+ from .services import SentimentAnalyzer
+ from functools import lru_cache
+
+
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
+ @lru_cache(maxsize=1)
+ def load_text_generator():
+     predictor = SentimentAnalyzer()
+     return predictor
+
+
+ predictor = load_text_generator()
+
+
+ def write():
+     st.markdown(
+         """
+ # Arabic Sentiment Analysis
+
+ This is a simple sentiment analysis app that uses the prediction kernel from Wissam's (me) submission that won the [Arabic Sentiment Analysis competition @ KAUST](https://www.kaggle.com/c/arabic-sentiment-analysis-2021-kaust).
+ """
+     )
+     if st.checkbox("More info: "):
+         st.markdown(
+             """
+ ### Submission Description:
+
+ My submission is based on an ensemble of 5 models with varying preprocessing and classifier designs. All model variants are built over MARBERT [1], which is a BERT-based model pre-trained on 1B dialectal Arabic tweets.
+
+ For preprocessing, all models shared the following steps:
+ - Replacing user mentions with “USER” and links with “URL”
+ - Replacing the “#” with “HASH”
+ - Removing the underscore character, since it is missing from the MARBERT vocabulary
+ - Removing diacritics and elongations (tatweel)
+ - Spacing out emojis
+
+ For classifier design, all models use a dense layer on top of MARBERT unless otherwise specified. Model training is done via hyperparameter grid search with 5-fold cross-validation over the following search space:
+ - Learning rate: [2e-5, 3e-5, 4e-5]
+ - Batch size: 128
+ - Maximum sequence length: 64
+ - Epochs: 3 (we select the best epoch for the final prediction)
+ - Warmup ratio: [0, 0.1]
+ - Seed: [1, 25, 42, 123, 666]
+
+ Model I is a vanilla variant with only the preprocessing steps mentioned above applied. Model II enhances the emoji representation by replacing OOV emojis with ones that have a similar meaning, for example 💊 → 😷.
+ We noticed the repetitive use of “السلام عليكم” and “ورحمة الله وبركاته” in neutral tweets, especially when users were directing questions to business accounts. This could confuse the classifier if it encountered these words in, for example, a negative tweet; hence in Model III we removed variations of these phrases using fuzzy matching algorithms.
+
+ In Model IV, we tried to help the model by appending a sarcasm label to the input. We first trained a separate MARBERT on the ArSarcasm [2] dataset, and then used it to label the training and test sets.
+
+ Model V uses the vanilla preprocessing approach, but instead of a dense layer built on top of MARBERT, we follow the approach detailed by Safaya et al. [3], which uses a CNN-based classifier instead.
+
+ For the final prediction, we first average the predictions of the 5 models from cross-validation (this is done for each model variant separately), and we then average the results from the 5 model variants. We observed that the distribution of the predicted sentiment classes doesn't quite match the true distribution; this is due to the model preferring the neutral class over the positive class. To counter that, we apply what we call a label-weighted average: after the final averaging, we rescale the scores with the weights 1.57, 0.98, and 0.93 for positive, neutral, and negative respectively (the weights were determined empirically).
+
+ 1- https://aclanthology.org/2021.acl-long.551/
+
+ 2- https://github.com/iabufarha/ArSarcasm
+
+ 3- https://github.com/alisafaya/OffensEval2020
+ """
+         )
+     input_text = st.text_input(
+         "Enter your text here:",
+     )
+     if st.button("Predict"):
+         with st.spinner("Predicting..."):
+             prediction, score, all_score = predictor.predict([input_text])
+         st.write(f"Result: {prediction[0]}")
+         detailed_score = {
+             "Positive": all_score[0][0],
+             "Neutral": all_score[0][1],
+             "Negative": all_score[0][2],
+         }
+         st.write("All scores:")
+         st.write(detailed_score)
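
The label-weighted averaging described on the page above is easy to state in isolation. A minimal sketch, assuming the per-class scores have already been averaged across folds and model variants; the weights 1.57/0.98/0.93 are the empirically chosen values quoted in the text:

    import numpy as np

    def label_weighted_prediction(avg_scores):
        # avg_scores: (positive, neutral, negative) after all averaging steps
        weights = np.array([1.57, 0.98, 0.93])
        rescaled = np.array(avg_scores) * weights
        labels = ["Positive", "Neutral", "Negative"]
        return labels[int(np.argmax(rescaled))], rescaled

    # A borderline neutral/positive example tips to Positive after reweighting:
    print(label_weighted_prediction((0.35, 0.40, 0.25))[0])  # -> "Positive"
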
backend/sa_utils.py ADDED
@@ -0,0 +1,510 @@
+ import re
+ from contextlib import contextmanager
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from fuzzysearch import find_near_matches
+ from pyarabic import araby
+ from torch import nn
+ from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, pipeline
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ from .preprocess import ArabertPreprocessor, url_regexes, user_mention_regex
+
+ multiple_char_pattern = re.compile(r"(.)\1{2,}", re.DOTALL)
+
+
+ # ASAD-NEW_AraBERT_PREP-Balanced
+ class NewArabicPreprocessorBalanced(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name: str,
+         keep_emojis: bool = False,
+         remove_html_markup: bool = True,
+         replace_urls_emails_mentions: bool = True,
+         strip_tashkeel: bool = True,
+         strip_tatweel: bool = True,
+         insert_white_spaces: bool = True,
+         remove_non_digit_repetition: bool = True,
+         replace_slash_with_dash: bool = None,
+         map_hindi_numbers_to_arabic: bool = None,
+         apply_farasa_segmentation: bool = None,
+     ):
+         if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
+             keep_emojis = True
+             remove_non_digit_repetition = True
+         super().__init__(
+             model_name=model_name,
+             keep_emojis=keep_emojis,
+             remove_html_markup=remove_html_markup,
+             replace_urls_emails_mentions=replace_urls_emails_mentions,
+             strip_tashkeel=strip_tashkeel,
+             strip_tatweel=strip_tatweel,
+             insert_white_spaces=insert_white_spaces,
+             remove_non_digit_repetition=remove_non_digit_repetition,
+             replace_slash_with_dash=replace_slash_with_dash,
+             map_hindi_numbers_to_arabic=map_hindi_numbers_to_arabic,
+             apply_farasa_segmentation=apply_farasa_segmentation,
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+         return super().preprocess(text)
+
+     def ubc_prep(self, text):
+         text = re.sub(r"\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = text.replace("\\r", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         text = re.sub(r"(URL\s*)+", " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         text = re.sub(r"(USER\s*)+", " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         # map escaped/OOV emoji codepoints to in-vocabulary equivalents
+         text = text.replace("\\U0001f97a", "🥺")
+         text = text.replace("\\U0001f928", "🤨")
+         text = text.replace("\\U0001f9d8", "😀")
+         text = text.replace("\\U0001f975", "😥")
+         text = text.replace("\\U0001f92f", "😲")
+         text = text.replace("\\U0001f92d", "🤭")
+         text = text.replace("\\U0001f9d1", "😐")
+         text = text.replace("\\U000e0067", "")
+         text = text.replace("\\U000e006e", "")
+         text = text.replace("\\U0001f90d", "♥")
+         text = text.replace("\\U0001f973", "🎉")
+         text = text.replace("\\U0001fa79", "")
+         text = text.replace("\\U0001f92b", "🤐")
+         text = text.replace("\\U0001f9da", "🦋")
+         text = text.replace("\\U0001f90e", "♥")
+         text = text.replace("\\U0001f9d0", "🧐")
+         text = text.replace("\\U0001f9cf", "")
+         text = text.replace("\\U0001f92c", "😠")
+         text = text.replace("\\U0001f9f8", "😸")
+         text = text.replace("\\U0001f9b6", "💩")
+         text = text.replace("\\U0001f932", "🤲")
+         text = text.replace("\\U0001f9e1", "🧡")
+         text = text.replace("\\U0001f974", "☹")
+         text = text.replace("\\U0001f91f", "")
+         text = text.replace("\\U0001f9fb", "💩")
+         text = text.replace("\\U0001f92a", "🤪")
+         text = text.replace("\\U0001f9fc", "")
+         text = text.replace("\\U000e0065", "")
+         text = text.replace("\\U0001f92e", "💩")
+         text = text.replace("\\U000e007f", "")
+         text = text.replace("\\U0001f970", "🥰")
+         text = text.replace("\\U0001f929", "🤩")
+         text = text.replace("\\U0001f6f9", "")
+         text = text.replace("🤍", "♥")
+         text = text.replace("🦠", "😷")
+         text = text.replace("🤢", "مقرف")
+         text = text.replace("🤮", "مقرف")
+         text = text.replace("🕠", "⌚")
+         text = text.replace("🤬", "😠")
+         text = text.replace("🤧", "😷")
+         text = text.replace("🥳", "🎉")
+         text = text.replace("🥵", "🔥")
+         text = text.replace("🥴", "☹")
+         text = text.replace("🤫", "🤐")
+         text = text.replace("🤥", "كذاب")
+         text = text.replace("\\u200d", " ")
+         text = text.replace("u200d", " ")
+         text = text.replace("\\u200c", " ")
+         text = text.replace("u200c", " ")
+         text = text.replace('"', "'")
+         text = text.replace("\\xa0", "")
+         text = text.replace("\\u2066", " ")
+         text = re.sub(r"\B\\[Uu]\w+", "", text)
+         text = super(NewArabicPreprocessorBalanced, self).preprocess(text)
+
+         text = " ".join(text.split())
+         return text
+
+
+ """CNNMarbertArabicPreprocessor"""
+ # ASAD-CNN_MARBERT
+ class CNNMarbertArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+         remove_elongations=True,
+     ):
+         if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
+             keep_emojis = True
+             remove_elongations = False
+         super().__init__(
+             model_name,
+             keep_emojis,
+             remove_html_markup,
+             replace_urls_emails_mentions,
+             remove_elongations,
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+         return super().preprocess(text)
+
+     def ubc_prep(self, text):
+         text = re.sub(r"\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         text = re.sub(r"(URL\s*)+", " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         text = re.sub(r"(USER\s*)+", " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         text = super(CNNMarbertArabicPreprocessor, self).preprocess(text)
+         text = text.replace("\u200d", " ")
+         text = text.replace("u200d", " ")
+         text = text.replace("\u200c", " ")
+         text = text.replace("u200c", " ")
+         text = text.replace('"', "'")
+         # text = re.sub('[\d\.]+', ' NUM ', text)
+         # text = re.sub('(NUM\s*)+', ' NUM ', text)
+         text = multiple_char_pattern.sub(r"\1\1", text)
+         text = " ".join(text.split())
+         return text
+
+
+ """Trial5ArabicPreprocessor"""
+
+
+ class Trial5ArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+     ):
+         if "UBC-NLP" in model_name:
+             keep_emojis = True
+         super().__init__(
+             model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+         return super().preprocess(text)
+
+     def ubc_prep(self, text):
+         text = re.sub(r"\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH TAG ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         text = super(Trial5ArabicPreprocessor, self).preprocess(text)
+         # text = text.replace("السلام عليكم"," ")
+         # text = text.replace(find_near_matches("السلام عليكم",text,max_deletions=3,max_l_dist=3)[0].matched," ")
+         return text
+
+
+ """SarcasmArabicPreprocessor"""
+
+
+ class SarcasmArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+     ):
+         if "UBC-NLP" in model_name:
+             keep_emojis = True
+         super().__init__(
+             model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+         else:
+             return super(SarcasmArabicPreprocessor, self).preprocess(text)
+
+     def ubc_prep(self, text):
+         text = re.sub(r"\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH TAG ")
+         text = text.replace("_", " ")
+         text = text.replace('"', " ")
+         text = " ".join(text.split())
+         text = super(SarcasmArabicPreprocessor, self).preprocess(text)
+         return text
+
+
+ """NoAOAArabicPreprocessor"""
+
+
+ class NoAOAArabicPreprocessor(ArabertPreprocessor):
+     def __init__(
+         self,
+         model_name,
+         keep_emojis=False,
+         remove_html_markup=True,
+         replace_urls_emails_mentions=True,
+     ):
+         if "UBC-NLP" in model_name:
+             keep_emojis = True
+         super().__init__(
+             model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
+         )
+         self.true_model_name = model_name
+
+     def preprocess(self, text):
+         if "UBC-NLP" in self.true_model_name:
+             return self.ubc_prep(text)
+         else:
+             return super(NoAOAArabicPreprocessor, self).preprocess(text)
+
+     def ubc_prep(self, text):
+         text = re.sub(r"\s", " ", text)
+         text = text.replace("\\n", " ")
+         text = araby.strip_tashkeel(text)
+         text = araby.strip_tatweel(text)
+         # replace all possible URLs
+         for reg in url_regexes:
+             text = re.sub(reg, " URL ", text)
+         # replace mentions with USER
+         text = re.sub(user_mention_regex, " USER ", text)
+         # replace hashtags with HASHTAG
+         # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
+         text = text.replace("#", " HASH TAG ")
+         text = text.replace("_", " ")
+         text = " ".join(text.split())
+         text = super(NoAOAArabicPreprocessor, self).preprocess(text)
+         # drop greeting phrases, including near matches, since they skew the neutral class
+         text = text.replace("السلام عليكم", " ")
+         text = text.replace("ورحمة الله وبركاته", " ")
+         matched = find_near_matches("السلام عليكم", text, max_deletions=3, max_l_dist=3)
+         if len(matched) > 0:
+             text = text.replace(matched[0].matched, " ")
+         matched = find_near_matches(
+             "ورحمة الله وبركاته", text, max_deletions=3, max_l_dist=3
+         )
+         if len(matched) > 0:
+             text = text.replace(matched[0].matched, " ")
+         return text
+
+
+ class CnnBertForSequenceClassification(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.bert = BertModel(config)
+
+         filter_sizes = [1, 2, 3, 4, 5]
+         num_filters = 32
+         self.convs1 = nn.ModuleList(
+             [nn.Conv2d(4, num_filters, (K, config.hidden_size)) for K in filter_sizes]
+         )
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(len(filter_sizes) * num_filters, config.num_labels)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # stack the last four hidden-state layers as 4 input channels for the CNN
+         x = outputs[2][-4:]
+
+         x = torch.stack(x, dim=1)
+         x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
+         x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
+         x = torch.cat(x, 1)
+         x = self.dropout(x)
+         logits = self.classifier(x)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (
+                     labels.dtype == torch.long or labels.dtype == torch.int
+                 ):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=None,
+             attentions=outputs.attentions,
+         )
+
+
+ class CNNTextClassificationPipeline:
+     def __init__(self, model_path, device, return_all_scores=False):
+         self.model_path = model_path
+         self.model = CnnBertForSequenceClassification.from_pretrained(self.model_path)
+         # Special handling
+         self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
+         if self.device.type == "cuda":
+             self.model = self.model.to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.return_all_scores = return_all_scores
+
+     @contextmanager
+     def device_placement(self):
+         """
+         Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.
+         Returns:
+             Context manager
+         Examples::
+             # Explicitly ask for tensor allocation on CUDA device :0
+             pipe = pipeline(..., device=0)
+             with pipe.device_placement():
+                 # Every framework specific tensor allocation will be done on the request device
+                 output = pipe(...)
+         """
+
+         if self.device.type == "cuda":
+             torch.cuda.set_device(self.device)
+
+         yield
+
+     def ensure_tensor_on_device(self, **inputs):
+         """
+         Ensure PyTorch tensors are on the specified device.
+         Args:
+             inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.
+         Return:
+             :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
+         """
+         return {
+             name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
+             for name, tensor in inputs.items()
+         }
+
+     def __call__(self, text):
+         """
+         Classify the text(s) given as inputs.
+         Args:
+             args (:obj:`str` or :obj:`List[str]`):
+                 One or several texts (or one list of prompts) to classify.
+         Return:
+             A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the following keys:
+             - **label** (:obj:`str`) -- The label predicted.
+             - **score** (:obj:`float`) -- The corresponding probability.
+             If ``self.return_all_scores=True``, one such dictionary is returned per label.
+         """
+         # outputs = super().__call__(*args, **kwargs)
+         inputs = self.tokenizer.batch_encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=64,
+             padding=True,
+             truncation="longest_first",
+             return_tensors="pt",
+         )
+
+         with torch.no_grad():
+             inputs = self.ensure_tensor_on_device(**inputs)
+             predictions = self.model(**inputs)[0].cpu()
+
+         predictions = predictions.numpy()
+
+         if self.model.config.num_labels == 1:
+             scores = 1.0 / (1.0 + np.exp(-predictions))
+         else:
+             scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)
+         if self.return_all_scores:
+             return [
+                 [
+                     {"label": self.model.config.id2label[i], "score": score.item()}
+                     for i, score in enumerate(item)
+                 ]
+                 for item in scores
+             ]
+         else:
+             return [
+                 {
+                     "label": self.model.config.id2label[item.argmax()],
+                     "score": item.max().item(),
+                 }
+                 for item in scores
+             ]
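
`CnnBertForSequenceClassification` follows the Safaya et al.-style design referenced earlier: the last four BERT hidden-state layers become four input channels, and each 2-D convolution spans the full hidden size so it slides only along the sequence. A shape walk-through under assumed sizes (batch 8, sequence 64, hidden 768; the numbers are illustrative):

    import torch
    import torch.nn.functional as F
    from torch import nn

    batch, seq_len, hidden = 8, 64, 768
    # stand-in for the stacked last four BERT hidden states
    x = torch.randn(batch, 4, seq_len, hidden)
    convs = nn.ModuleList([nn.Conv2d(4, 32, (k, hidden)) for k in (1, 2, 3, 4, 5)])
    feats = [F.relu(conv(x)).squeeze(3) for conv in convs]           # (8, 32, 64-k+1)
    pooled = [F.max_pool1d(f, f.size(2)).squeeze(2) for f in feats]  # (8, 32) each
    features = torch.cat(pooled, 1)                                  # (8, 160) -> classifier
    print(features.shape)  # torch.Size([8, 160])
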
backend/sarcasm.py ADDED
@@ -0,0 +1,21 @@
+ import streamlit as st
+ from .sa import predictor
+
+
+ def write():
+     st.markdown(
+         """
+ # Arabic Sarcasm Detection
+
+ This is a simple sarcasm detection app that uses the [MARBERT](https://huggingface.co/UBC-NLP/MARBERT) model trained on [ArSarcasm](https://github.com/iabufarha/ArSarcasm)
+ """
+     )
+
+     input_text = st.text_input(
+         "Enter your text here:",
+     )
+     if st.button("Predict"):
+         with st.spinner("Predicting..."):
+             prediction, scores = predictor.get_preds_from_sarcasm([input_text])
+         st.write(f"Result: {prediction[0]}")
+         st.write(f"Score: {scores[0]}")
backend/services.py ADDED
@@ -0,0 +1,519 @@
+ import json
+ import logging
+ import os
+ import re
+ from functools import lru_cache
+ from typing import List
+ from urllib.parse import unquote
+
+ import more_itertools
+ import numpy as np
+ import pandas as pd
+ import requests
+ import streamlit as st
+ import wikipedia
+ from codetiming import Timer
+ from fuzzysearch import find_near_matches
+ from googleapi import google
+ from tqdm.auto import tqdm
+ from transformers import (
+     AutoTokenizer,
+     GPT2LMHeadModel,
+     GPT2Tokenizer,
+     pipeline,
+     set_seed,
+ )
+
+ from .modeling_gpt2 import GPT2LMHeadModel as GROVERLMHeadModel
+ from .preprocess import ArabertPreprocessor
+ from .sa_utils import *
+ from .utils import download_models, softmax
+
+ logger = logging.getLogger(__name__)
+
+
+ # Taken and Modified from https://huggingface.co/spaces/flax-community/chef-transformer/blob/main/app.py
+ class TextGeneration:
+     def __init__(self):
+         self.debug = False
+         self.generation_pipline = {}
+         self.preprocessor = ArabertPreprocessor(model_name="aragpt2-mega")
+         self.tokenizer = GPT2Tokenizer.from_pretrained(
+             "aubmindlab/aragpt2-mega", use_fast=False
+         )
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.API_KEY = os.getenv("API_KEY")
+         self.headers = {"Authorization": f"Bearer {self.API_KEY}"}
+         # self.model_names_or_paths = {
+         #     "aragpt2-medium": "D:/ML/Models/aragpt2-medium",
+         #     "aragpt2-base": "D:/ML/Models/aragpt2-base",
+         # }
+         self.model_names_or_paths = {
+             # "aragpt2-medium": "aubmindlab/aragpt2-medium",
+             "aragpt2-base": "aubmindlab/aragpt2-base",
+             # "aragpt2-large": "aubmindlab/aragpt2-large",
+             "aragpt2-mega": "aubmindlab/aragpt2-mega",
+         }
+         set_seed(42)
+
+     def load_pipeline(self):
+         for model_name, model_path in self.model_names_or_paths.items():
+             if "base" in model_name or "medium" in model_name:
+                 self.generation_pipline[model_name] = pipeline(
+                     "text-generation",
+                     model=GPT2LMHeadModel.from_pretrained(model_path),
+                     tokenizer=self.tokenizer,
+                     device=-1,
+                 )
+             else:
+                 self.generation_pipline[model_name] = pipeline(
+                     "text-generation",
+                     model=GROVERLMHeadModel.from_pretrained(model_path),
+                     tokenizer=self.tokenizer,
+                     device=-1,
+                 )
+
+     def load(self):
+         if not self.debug:
+             self.load_pipeline()
+
+     def generate(
+         self,
+         model_name,
+         prompt,
+         max_new_tokens: int,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+         repetition_penalty: float,
+         no_repeat_ngram_size: int,
+         do_sample: bool,
+         num_beams: int,
+     ):
+         logger.info(f"Generating with {model_name}")
+         prompt = self.preprocessor.preprocess(prompt)
+         return_full_text = False
+         return_text = True
+         num_return_sequences = 1
+         pad_token_id = 0
+         eos_token_id = 0
+         input_tok = self.tokenizer.tokenize(prompt)
+         max_length = len(input_tok) + max_new_tokens
+         if max_length > 1024:
+             max_length = 1024
+         if not self.debug:
+             generated_text = self.generation_pipline[model_name.lower()](
+                 prompt,
+                 max_length=max_length,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+                 pad_token_id=pad_token_id,
+                 eos_token_id=eos_token_id,
+                 return_full_text=return_full_text,
+                 return_text=return_text,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 num_return_sequences=num_return_sequences,
+             )[0]["generated_text"]
+         else:
+             generated_text = self.generate_by_query(
+                 prompt,
+                 model_name,
+                 max_length=max_length,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+                 pad_token_id=pad_token_id,
+                 eos_token_id=eos_token_id,
+                 return_full_text=return_full_text,
+                 return_text=return_text,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 num_return_sequences=num_return_sequences,
+             )
+             # print(generated_text)
+             if isinstance(generated_text, dict):
+                 if "error" in generated_text:
+                     if "is currently loading" in generated_text["error"]:
+                         return f"Model is currently loading, estimated time is {generated_text['estimated_time']}"
+                     return generated_text["error"]
+                 else:
+                     return "Something happened 🤷‍♂️!!"
+             else:
+                 generated_text = generated_text[0]["generated_text"]
+
+         logger.info(f"Prompt: {prompt}")
+         logger.info(f"Generated text: {generated_text}")
+         return self.preprocessor.unpreprocess(generated_text)
+
+     def query(self, payload, model_name):
+         data = json.dumps(payload)
+         url = (
+             "https://api-inference.huggingface.co/models/aubmindlab/"
+             + model_name.lower()
+         )
+         response = requests.request("POST", url, headers=self.headers, data=data)
+         return json.loads(response.content.decode("utf-8"))
+
+     def generate_by_query(
+         self,
+         prompt: str,
+         model_name: str,
+         max_length: int,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+         repetition_penalty: float,
+         no_repeat_ngram_size: int,
+         pad_token_id: int,
+         eos_token_id: int,
+         return_full_text: int,
+         return_text: int,
+         do_sample: bool,
+         num_beams: int,
+         num_return_sequences: int,
+     ):
+         payload = {
+             "inputs": prompt,
+             "parameters": {
+                 "max_length": max_length,
+                 "top_k": top_k,
+                 "top_p": top_p,
+                 "temperature": temperature,
+                 "repetition_penalty": repetition_penalty,
+                 "no_repeat_ngram_size": no_repeat_ngram_size,
+                 "pad_token_id": pad_token_id,
+                 "eos_token_id": eos_token_id,
+                 "return_full_text": return_full_text,
+                 "return_text": return_text,
+                 "do_sample": do_sample,
+                 "num_beams": num_beams,
+                 "num_return_sequences": num_return_sequences,
+             },
+             "options": {
+                 "use_cache": True,
+             },
+         }
+         return self.query(payload, model_name)
+
+
+ class SentimentAnalyzer:
+     def __init__(self):
+         self.sa_models = [
+             "sa_trial5_1",
+             # "sa_no_aoa_in_neutral",
+             # "sa_cnnbert",
+             # "sa_sarcasm",
+             # "sar_trial10",
+             # "sa_no_AOA",
+         ]
+         download_models(self.sa_models)
+         # fmt: off
+         self.processors = {
+             "sa_trial5_1": Trial5ArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             # "sa_no_aoa_in_neutral": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
+             # "sa_cnnbert": CNNMarbertArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             # "sa_sarcasm": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             # "sar_trial10": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
+             # "sa_no_AOA": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
+         }
+
+         self.pipelines = {
+             "sa_trial5_1": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_trial5_1",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_trial5_1")],
+             # "sa_no_aoa_in_neutral": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_aoa_in_neutral",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_aoa_in_neutral")],
+             # "sa_cnnbert": [CNNTextClassificationPipeline("{}/train_{}/best_model".format("sa_cnnbert",i), device=-1, return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_cnnbert")],
+             # "sa_sarcasm": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_sarcasm",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_sarcasm")],
+             # "sar_trial10": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sar_trial10",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sar_trial10")],
+             # "sa_no_AOA": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_AOA",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_AOA")],
+         }
+         # fmt: on
+
+     def get_preds_from_sarcasm(self, texts):
+         prep = self.processors["sar_trial10"]
+         prep_texts = [prep.preprocess(x) for x in texts]
+
+         preds_df = pd.DataFrame([])
+         for i in range(0, 5):
+             preds = []
+             for s in more_itertools.chunked(list(prep_texts), 128):
+                 preds.extend(self.pipelines["sar_trial10"][i](s))
+             preds_df[f"model_{i}"] = preds
+
+         final_labels = []
+         final_scores = []
+         for id, row in preds_df.iterrows():
+             pos_total = 0
+             neu_total = 0
+             for pred in row[:]:
+                 pos_total += pred[0]["score"]
+                 neu_total += pred[1]["score"]
+
+             pos_avg = pos_total / len(row[:])
+             neu_avg = neu_total / len(row[:])
+
+             final_labels.append(
+                 self.pipelines["sar_trial10"][0].model.config.id2label[
+                     np.argmax([pos_avg, neu_avg])
+                 ]
+             )
+             final_scores.append(np.max([pos_avg, neu_avg]))
+
+         return final_labels, final_scores
+
+     def get_preds_from_a_model(self, texts: List[str], model_name):
+         try:
+             prep = self.processors[model_name]
+
+             prep_texts = [prep.preprocess(x) for x in texts]
+             if model_name == "sa_sarcasm":
+                 sarcasm_label, _ = self.get_preds_from_sarcasm(texts)
+                 sarcastic_map = {"Not_Sarcastic": "غير ساخر", "Sarcastic": "ساخر"}
+                 labeled_prep_texts = []
+                 for t, l in zip(prep_texts, sarcasm_label):
+                     labeled_prep_texts.append(sarcastic_map[l] + " [SEP] " + t)
+                 # feed the sarcasm-labeled inputs to the sentiment pipelines
+                 prep_texts = labeled_prep_texts
+
+             preds_df = pd.DataFrame([])
+             for i in range(0, 5):
+                 preds = []
+                 for s in more_itertools.chunked(list(prep_texts), 128):
+                     preds.extend(self.pipelines[model_name][i](s))
+                 preds_df[f"model_{i}"] = preds
+
+             final_labels = []
+             final_scores = []
+             final_scores_list = []
+             for id, row in preds_df.iterrows():
+                 pos_total = 0
+                 neg_total = 0
+                 neu_total = 0
+                 # average over all 5 cross-validation folds
+                 for pred in row[:]:
+                     pos_total += pred[0]["score"]
+                     neu_total += pred[1]["score"]
+                     neg_total += pred[2]["score"]
+
+                 pos_avg = pos_total / 5
+                 neu_avg = neu_total / 5
+                 neg_avg = neg_total / 5
+
+                 if model_name == "sa_no_aoa_in_neutral":
+                     final_labels.append(
+                         self.pipelines[model_name][0].model.config.id2label[
+                             np.argmax([neu_avg, neg_avg, pos_avg])
+                         ]
+                     )
+                 else:
+                     final_labels.append(
+                         self.pipelines[model_name][0].model.config.id2label[
+                             np.argmax([pos_avg, neu_avg, neg_avg])
+                         ]
+                     )
+                 final_scores.append(np.max([pos_avg, neu_avg, neg_avg]))
+                 final_scores_list.append((pos_avg, neu_avg, neg_avg))
+         except RuntimeError:
+             if model_name == "sa_cnnbert":
+                 return (
+                     ["Neutral"] * len(texts),
+                     [0.0] * len(texts),
+                     [(0.0, 0.0, 0.0)] * len(texts),
+                 )
+             else:
+                 raise
+         return final_labels, final_scores, final_scores_list
+
+     def predict(self, texts: List[str]):
+         logger.info(f"Predicting for: {texts}")
+         # (
+         #     new_balanced_label,
+         #     new_balanced_score,
+         #     new_balanced_score_list,
+         # ) = self.get_preds_from_a_model(texts, "sa_no_aoa_in_neutral")
+         # (
+         #     cnn_marbert_label,
+         #     cnn_marbert_score,
+         #     cnn_marbert_score_list,
+         # ) = self.get_preds_from_a_model(texts, "sa_cnnbert")
+         trial5_label, trial5_score, trial5_score_list = self.get_preds_from_a_model(
+             texts, "sa_trial5_1"
+         )
+         # no_aoa_label, no_aoa_score, no_aoa_score_list = self.get_preds_from_a_model(
+         #     texts, "sa_no_AOA"
+         # )
+         # sarcasm_label, sarcasm_score, sarcasm_score_list = self.get_preds_from_a_model(
+         #     texts, "sa_sarcasm"
+         # )
+
+         id_label_map = {0: "Positive", 1: "Neutral", 2: "Negative"}
+
+         final_ensemble_prediction = []
+         final_ensemble_score = []
+         final_ensemble_all_score = []
+         for entry in zip(
+             # new_balanced_score_list,
+             # cnn_marbert_score_list,
+             trial5_score_list,
+             # no_aoa_score_list,
+             # sarcasm_score_list,
+         ):
+             pos_score = 0
+             neu_score = 0
+             neg_score = 0
+             for s in entry:
+                 # label-weighted average (empirically determined weights)
+                 pos_score += s[0] * 1.57
+                 neu_score += s[1] * 0.98
+                 neg_score += s[2] * 0.93
+
+                 # weighted 2
+                 # pos_score += s[0]*1.67
+                 # neu_score += s[1]
+                 # neg_score += s[2]*0.95
+
+             final_ensemble_prediction.append(
+                 id_label_map[np.argmax([pos_score, neu_score, neg_score])]
+             )
+             final_ensemble_score.append(np.max([pos_score, neu_score, neg_score]))
+             final_ensemble_all_score.append(
+                 softmax(np.array([pos_score, neu_score, neg_score])).tolist()
+             )
+
+         logger.info(f"Result: {final_ensemble_prediction}")
+         logger.info(f"Score: {final_ensemble_score}")
+         logger.info(f"All Scores: {final_ensemble_all_score}")
+         return final_ensemble_prediction, final_ensemble_score, final_ensemble_all_score
+
+
+ wikipedia.set_lang("ar")
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ preprocessor = ArabertPreprocessor("wissamantoun/araelectra-base-artydiqa")
+ logger.info("Loading QA Pipeline...")
+ tokenizer = AutoTokenizer.from_pretrained("wissamantoun/araelectra-base-artydiqa")
+ qa_pipe = pipeline("question-answering", model="wissamantoun/araelectra-base-artydiqa")
+ logger.info("Finished loading QA Pipeline...")
+
+
+ @lru_cache(maxsize=100)
+ def get_qa_answers(question):
+     logger.info("\n=================================================================")
+     logger.info(f"Question: {question}")
+
+     if "وسام أنطون" in question or "wissam antoun" in question.lower():
+         return {
+             "title": "Creator",
+             "results": [
+                 {
+                     "score": 1.0,
+                     "new_start": 0,
+                     "new_end": 12,
+                     "new_answer": "My Creator 😜",
+                     "original": "My Creator 😜",
+                     "link": "https://github.com/WissamAntoun/",
+                 }
+             ],
+         }
+     search_timer = Timer(
+         "search and wiki", text="Search and Wikipedia Time: {:.2f}", logger=logging.info
+     )
+     try:
+         search_timer.start()
+         search_results = google.search(
+             question + " site:ar.wikipedia.org", lang="ar", area="ar"
+         )
+         if len(search_results) == 0:
+             return {}
+
+         page_name = search_results[0].link.split("wiki/")[-1]
+         wiki_page = wikipedia.page(unquote(page_name))
+         wiki_page_content = wiki_page.content
+         search_timer.stop()
+     except Exception:
+         return {}
+
+     sections = []
+     for section in re.split(r"== .+ ==[^=]", wiki_page_content):
+         if not section.isspace():
+             prep_section = tokenizer.tokenize(preprocessor.preprocess(section))
+             if len(prep_section) > 500:
+                 subsections = []
+                 for subsection in re.split(r"=== .+ ===", section):
+                     if subsection.isspace():
+                         continue
+                     prep_subsection = tokenizer.tokenize(
+                         preprocessor.preprocess(subsection)
+                     )
+                     subsections.append(subsection)
+                     # logger.info(f"Subsection found with length: {len(prep_subsection)}")
+                 sections.extend(subsections)
+             else:
+                 # logger.info(f"Regular Section with length: {len(prep_section)}")
+                 sections.append(section)
+
+     full_len_sections = []
+     temp_section = ""
+     for section in sections:
+         if (
+             len(tokenizer.tokenize(preprocessor.preprocess(temp_section)))
+             + len(tokenizer.tokenize(preprocessor.preprocess(section)))
+             > 384
+         ):
+             if temp_section == "":
+                 temp_section = section
+                 continue
+             full_len_sections.append(temp_section)
+             # logger.info(
+             #     f"full section length: {len(tokenizer.tokenize(preprocessor.preprocess(temp_section)))}"
+             # )
+             temp_section = ""
+         else:
+             temp_section += " " + section + " "
+     if temp_section != "":
+         full_len_sections.append(temp_section)
+
+     reader_time = Timer("electra", text="Reader Time: {:.2f}", logger=logging.info)
+     reader_time.start()
+     results = qa_pipe(
+         question=[preprocessor.preprocess(question)] * len(full_len_sections),
+         context=[preprocessor.preprocess(x) for x in full_len_sections],
+     )
+
+     if not isinstance(results, list):
+         results = [results]
+
+     logger.info(f"Wiki Title: {unquote(page_name)}")
+     logger.info(f"Total Sections: {len(sections)}")
+     logger.info(f"Total Full Sections: {len(full_len_sections)}")
+
+     for result, section in zip(results, full_len_sections):
+         result["original"] = section
+         answer_match = find_near_matches(
+             " " + preprocessor.unpreprocess(result["answer"]) + " ",
+             result["original"],
+             max_l_dist=min(5, len(preprocessor.unpreprocess(result["answer"])) // 2),
+             max_deletions=0,
+         )
+         try:
+             result["new_start"] = answer_match[0].start
+             result["new_end"] = answer_match[0].end
+             result["new_answer"] = answer_match[0].matched
+             result["link"] = (
+                 search_results[0].link + "#:~:text=" + result["new_answer"].strip()
+             )
+         except Exception:
+             result["new_start"] = result["start"]
+             result["new_end"] = result["end"]
+             result["new_answer"] = result["answer"]
+             result["original"] = preprocessor.preprocess(result["original"])
+             result["link"] = search_results[0].link
+         logger.info(f"Answers: {preprocessor.preprocess(result['new_answer'])}")
+
+     sorted_results = sorted(results, reverse=True, key=lambda x: x["score"])
+
+     return_dict = {}
+     return_dict["title"] = unquote(page_name)
+     return_dict["results"] = sorted_results
+
+     reader_time.stop()
+     logger.info(f"Total time spent: {reader_time.last + search_timer.last}")
+     return return_dict
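
`get_qa_answers` greedily packs consecutive Wikipedia sections into contexts of at most 384 tokens before sending them to the reader. A simplified sketch of that packing loop, with a plain whitespace token count standing in for the AraELECTRA tokenizer (it mirrors the flush behavior of the loop above, including skipping the section that triggers the flush):

    def pack_sections(sections, max_tokens=384, count=lambda s: len(s.split())):
        # Greedily merge consecutive sections until the budget would overflow.
        packed, current = [], ""
        for section in sections:
            if count(current) + count(section) > max_tokens:
                if current == "":
                    current = section
                    continue
                packed.append(current)
                current = ""
            else:
                current += " " + section + " "
        if current != "":
            packed.append(current)
        return packed
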
backend/utils.py ADDED
@@ -0,0 +1,64 @@
+ import re
+ import numpy as np
+ import psutil
+ import os
+ from tqdm.auto import tqdm
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_current_ram_usage():
+     ram = psutil.virtual_memory()
+     return ram.available / 1024 / 1024 / 1024, ram.total / 1024 / 1024 / 1024
+
+
+ def download_models(models):
+     for model in tqdm(models, desc="Downloading models"):
+         logger.info(f"Downloading {model}")
+         for i in range(0, 5):
+             curr_dir = f"{model}/train_{i}/best_model/"
+             os.makedirs(curr_dir, exist_ok=True)
+             os.system(
+                 f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/config.json -P {curr_dir}"
+             )
+             os.system(
+                 f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/pytorch_model.bin -P {curr_dir}"
+             )
+             os.system(
+                 f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/special_tokens_map.json -P {curr_dir}"
+             )
+             os.system(
+                 f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/tokenizer_config.json -P {curr_dir}"
+             )
+             os.system(
+                 f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/training_args.bin -P {curr_dir}"
+             )
+             os.system(
+                 f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/vocab.txt -P {curr_dir}"
+             )
+
+
+ def softmax(x):
+     # subtract the max for numerical stability; the result is unchanged
+     x = x - np.max(x)
+     return np.exp(x) / np.sum(np.exp(x))
+
+
+ def ga(file):
+     code = """
+     <!-- Global site tag (gtag.js) - Google Analytics -->
+     <script async src="https://www.googletagmanager.com/gtag/js?id=G-NH9HWCW08F"></script>
+     <script>
+       window.dataLayer = window.dataLayer || [];
+       function gtag(){dataLayer.push(arguments);}
+       gtag('js', new Date());
+       gtag('config', 'G-NH9HWCW08F');
+     </script>
+     """
+
+     # inject the analytics snippet into Streamlit's static index.html, once
+     a = os.path.dirname(file) + "/static/index.html"
+     with open(a, "r") as f:
+         data = f.read()
+     if len(re.findall("G-", data)) == 0:
+         with open(a, "w") as ff:
+             newdata = re.sub("<head>", "<head>" + code, data)
+             ff.write(newdata)
packages.txt ADDED
@@ -0,0 +1,2 @@
+ openjdk-11-jre
+ curl
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ streamlit==0.84.2
+ arabic-reshaper==2.1.3
+ python-bidi==0.4.2
+ PyArabic
+ farasapy==0.0.14
+ emoji==1.4.2
+ awesome_streamlit
+ torch==1.9.0
+ transformers==4.10.0
+ psutil==5.8.0
+ fuzzysearch==0.7.3
+ more-itertools==8.9.0
+ cookiecutter
+ git+https://github.com/dantru7/Google-Search-API
+ codetiming==1.3.0
+ htbuilder
+ wikipedia==1.4.0
test.py ADDED
@@ -0,0 +1,10 @@
+ #%%
+ from transformers import GPT2Tokenizer
+
+ # %%
+ tok = GPT2Tokenizer.from_pretrained("D:/ML/Models/aragpt2-medium", use_fast=False)
+ # %%
+ tok.pad_token = tok.eos_token
+ #%%
+ tok.pad_token_id = tok.eos_token_id  # the id must be a single int, not a list
+ # %%