abxhr committed on
Commit 34473f3
1 Parent(s): ce608f1
Pipfile ADDED
@@ -0,0 +1,24 @@
+ [[source]]
+ url = "https://pypi.org/simple"
+ verify_ssl = true
+ name = "pypi"
+
+ [packages]
+ streamlit = "==0.84.2"
+ arabic-reshaper = "==2.1.3"
+ python-bidi = "==0.4.2"
+ pyarabic = "*"
+ farasapy = "==0.0.14"
+ emoji = "==1.4.2"
+ awesome-streamlit = "*"
+ torch = "==1.9.0"
+ transformers = "==4.10.0"
+ psutil = "==5.8.0"
+ fuzzysearch = "==0.7.3"
+ more-itertools = "==8.9.0"
+ cookiecutter = "*"
+
+ [dev-packages]
+
+ [requires]
+ python_version = "3.8"
README.md CHANGED
@@ -1,13 +1,9 @@
  ---
- title: Design Project
- emoji: 🏃
+ title: Arabic NLP Demo
+ emoji:
  colorFrom: purple
  colorTo: green
  sdk: streamlit
- sdk_version: 1.9.0
  app_file: app.py
- pinned: false
- license: unlicense
+ pinned: true
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
app.py ADDED
@@ -0,0 +1,48 @@
+ import awesome_streamlit as ast
+ import streamlit as st
+
+ from backend.utils import get_current_ram_usage, ga
+
+ import backend.aragpt
+ import backend.home
+ import backend.processor
+ import backend.sa
+ import backend.qa
+
+ st.set_page_config(
+     page_title="TEST", page_icon="📖", initial_sidebar_state="expanded", layout="wide"
+ )
+
+ ga(st.__file__)
+
+ PAGES = {
+     "Home": backend.home,
+     "Arabic Text Preprocessor": backend.processor,
+     "Arabic Language Generation": backend.aragpt,
+     "Arabic Sentiment Analysis": backend.sa,
+     # "Arabic Sarcasm Detection": backend.sarcasm,
+     "Arabic Question Answering": backend.qa,
+ }
+
+
+ st.sidebar.title("Navigation")
+ selection = st.sidebar.radio("Pages", list(PAGES.keys()))
+
+ page = PAGES[selection]
+ # with st.spinner(f"Loading {selection} ..."):
+ ast.shared.components.write_page(page)
+
+ st.sidebar.header("Info")
+ st.sidebar.write("Made by [Wissam Antoun](https://twitter.com/wissam_antoun)")
+ st.sidebar.write(
+     "Pre-trained models are available on [HF Hub](https://huggingface.co/aubmindlab)"
+ )
+ st.sidebar.write(
+     "Models source code available on [GitHub](https://github.com/aub-mind/arabert)"
+ )
+ st.sidebar.write(
+     "App source code available on [GitHub](https://github.com/WissamAntoun/Arabic-NLP-app)"
+ )
+ if st.sidebar.checkbox("Show RAM usage"):
+     ram = get_current_ram_usage()
+     st.sidebar.write("RAM usage: {:.2f}/{:.2f} GB".format(ram[0], ram[1]))
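
Note: `backend/utils.py` itself is not part of this excerpt, so `ga` and `get_current_ram_usage` only appear at their call sites above. A minimal sketch (not the repository's actual implementation) of the RAM helper, assuming it is backed by the `psutil` dependency pinned in the Pipfile:

```python
import psutil


def get_current_ram_usage():
    # Return (used, total) system memory in GB, matching the
    # "RAM usage: {used:.2f}/{total:.2f} GB" string app.py writes to the sidebar.
    mem = psutil.virtual_memory()
    gb = 1024 ** 3
    return mem.used / gb, mem.total / gb
```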
backend/__init__.py ADDED
File without changes
backend/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (214 Bytes)

backend/__pycache__/aragpt.cpython-38.pyc ADDED
Binary file (4.41 kB)

backend/__pycache__/modeling_gpt2.cpython-38.pyc ADDED
Binary file (42.9 kB)

backend/__pycache__/preprocess.cpython-38.pyc ADDED
Binary file (17.8 kB)

backend/__pycache__/sa_utils.cpython-38.pyc ADDED
Binary file (14.5 kB)

backend/__pycache__/services.cpython-38.pyc ADDED
Binary file (11.6 kB)

backend/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.28 kB)
backend/aragpt.py ADDED
@@ -0,0 +1,189 @@
+ import streamlit as st
+ from .services import TextGeneration
+ from tokenizers import Tokenizer
+ from functools import lru_cache
+
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
+ @lru_cache(maxsize=1)
+ def load_text_generator():
+     generator = TextGeneration()
+     generator.load()
+     return generator
+
+
+ generator = load_text_generator()
+
+ qa_prompt = """
+ أجب عن السؤال التالي:
+ """
+ qa_prompt_post = """ الجواب هو """
+ qa_prompt_post_year = """ في سنة: """
+
+
+ def write():
+     st.markdown(
+         """
+ <h1 style="text-align:left;">Arabic Language Generation</h1>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # Sidebar
+
+     # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
+     st.sidebar.subheader("Configurable parameters")
+
+     model_name = st.sidebar.selectbox(
+         "Model Selector",
+         options=[
+             "AraGPT2-Base",
+             # "AraGPT2-Medium",
+             # "Aragpt2-Large",
+             "AraGPT2-Mega",
+         ],
+         index=0,
+     )
+
+     max_new_tokens = st.sidebar.number_input(
+         "Maximum length",
+         min_value=0,
+         max_value=1024,
+         value=100,
+         help="The maximum length of the sequence to be generated.",
+     )
+     temp = st.sidebar.slider(
+         "Temperature",
+         value=1.0,
+         min_value=0.1,
+         max_value=100.0,
+         help="The value used to modulate the next token probabilities.",
+     )
+     top_k = st.sidebar.number_input(
+         "Top k",
+         value=10,
+         help="The number of highest probability vocabulary tokens to keep for top-k filtering.",
+     )
+     top_p = st.sidebar.number_input(
+         "Top p",
+         value=0.95,
+         help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
+     )
+     do_sample = st.sidebar.selectbox(
+         "Sampling?",
+         (True, False),
+         help="Whether or not to use sampling; use greedy decoding otherwise.",
+     )
+     num_beams = st.sidebar.number_input(
+         "Number of beams",
+         min_value=1,
+         max_value=10,
+         value=3,
+         help="The number of beams to use for beam search.",
+     )
+     repetition_penalty = st.sidebar.number_input(
+         "Repetition Penalty",
+         min_value=0.0,
+         value=3.0,
+         step=0.1,
+         help="The parameter for repetition penalty. 1.0 means no penalty.",
+     )
+     no_repeat_ngram_size = st.sidebar.number_input(
+         "No Repeat N-Gram Size",
+         min_value=0,
+         value=3,
+         help="If set to int > 0, all ngrams of that size can only occur once.",
+     )
+
+     st.write("#")
+
+     col = st.columns(2)
+
+     col[0].image("images/AraGPT2.png", width=200)
+
+     st.markdown(
+         """
+
+ <h3 style="text-align:left;">AraGPT2 is a GPT2 model trained from scratch on 77GB of Arabic text.</h3>
+ <h4 style="text-align:left;"> More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4>
+
+ <p style="text-align:left;"></p>
+ <p style="text-align:left;">Use the generation parameters on the sidebar to adjust generation quality.</p>
+ <p style="text-align:right;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # col[0].write(
+     #     "AraGPT2 is trained from scratch on 77GB of Arabic text. More details in our [repo](https://github.com/aub-mind/arabert/tree/master/aragpt2)."
+     # )
+     # st.write("## Generate Arabic Text")
+
+     st.markdown(
+         """
+ <style>
+ p, div, input, label, textarea {
+     text-align: right;
+ }
+ </style>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     prompt = st.text_area(
+         "Prompt",
+         "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
+     )
+     if st.button("Generate"):
+         with st.spinner("Generating..."):
+             generated_text = generator.generate(
+                 prompt=prompt,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+             st.write(generated_text)
+
+     st.markdown("---")
+     st.subheader("")
+     st.markdown(
+         """
+ <p style="text-align:left;"></p>
+ <h2 style="text-align:left;">Zero-Shot Question Answering</h2>
+
+ <p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended.</p>
+ <p style="text-align:left;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     question = st.text_input(
+         "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
+     )
+     is_date = st.checkbox("Help the model: Is the answer a date?")
+     if st.button("Answer"):
+         prompt2 = qa_prompt + question + qa_prompt_post
+         if is_date:
+             prompt2 += qa_prompt_post_year
+         else:
+             prompt2 += " : "
+         with st.spinner("Thinking..."):
+             answer = generator.generate(
+                 prompt=prompt2,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+             st.write(answer)
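
Note: `backend/services.py` (the `TextGeneration` class imported above) is not included in this excerpt; only its `load()` and `generate(...)` call sites are visible. Below is a rough, hypothetical sketch of an interface that would satisfy those calls using the pinned `transformers==4.10.0`; the repository's real service also handles the custom GROVER-style AraGPT2-Mega model via `modeling_gpt2.py`, which this sketch does not.

```python
from transformers import GPT2LMHeadModel, GPT2TokenizerFast


class TextGeneration:
    """Hypothetical sketch of the interface aragpt.py relies on."""

    def load(self):
        # Only the standard-architecture base model is sketched here.
        self.tokenizers, self.models = {}, {}
        for name, repo in {"AraGPT2-Base": "aubmindlab/aragpt2-base"}.items():
            self.tokenizers[name] = GPT2TokenizerFast.from_pretrained(repo)
            self.models[name] = GPT2LMHeadModel.from_pretrained(repo)

    def generate(self, prompt, model_name, max_new_tokens, temperature, top_k,
                 top_p, repetition_penalty, do_sample, num_beams, no_repeat_ngram_size):
        tokenizer = self.tokenizers[model_name]
        model = self.models[model_name]
        inputs = tokenizer(prompt, return_tensors="pt")
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```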
backend/home.py ADDED
@@ -0,0 +1,156 @@
+ import streamlit as st
+ import awesome_streamlit as ast
+
+
+ def write():
+     st.markdown(
+         """
+ # Arabic Natural Language Processing
+
+ ![visitors](https://visitor-badge.glitch.me/badge?page_id=wissamantoun.arabicnlpapp)
+
+
+ In this HuggingFace Space you can test the different Arabic NLP models that my colleagues at [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) have built, along with some other applications.
+
+ Check the **Navigation bar** to access the apps:
+ - Arabic Text Preprocessor: Test how text input is treated by our preprocessor
+ - Arabic Language Generation: Generate Arabic text using our AraGPT2 language models
+ - Arabic Sentiment Analysis: Test the sentiment analysis model that won the [Arabic Sentiment Analysis competition @ KAUST](https://www.kaggle.com/c/arabic-sentiment-analysis-2021-kaust)
+ - Arabic Question Answering: Test our AraELECTRA QA capabilities
+ """
+     )
+     st.markdown("#")
+     col1, col2, col3 = st.columns(3)
+
+     col1.write("## **AraBERT**")
+     col1.image("images/arabert_logo.png", width=200)
+
+     col2.write("## **AraGPT2**")
+     col2.image("images/AraGPT2.png", width=200)
+
+     col3.write("## **AraElectra**")
+     col3.image("images/AraELECTRA.png", width=200)
+
+     st.markdown(
+         """
+
+ You can find more details in the source code and papers linked in our GitHub [repo](https://github.com/aub-mind/arabert).
+
+ ## Dataset
+
+ The pretraining data used for the new **AraBERT** model is also used for **AraGPT2 and AraELECTRA**.
+
+ The dataset consists of 77GB, or 200,095,961 lines, or 8,655,948,860 words, or 82,232,988,358 chars (before applying Farasa segmentation).
+
+ Our large models were trained on a TPUv3-128 provided by TFRC.
+
+ For the new dataset we added the unshuffled OSCAR corpus, after thoroughly filtering it, to the dataset used in AraBERTv1, but without the websites that we previously crawled:
+ - OSCAR unshuffled and filtered.
+ - [Arabic Wikipedia dump](https://archive.org/details/arwiki-20190201) from 2020/09/01
+ - [The 1.5B words Arabic Corpus](https://www.semanticscholar.org/paper/1.5-billion-words-Arabic-Corpus-El-Khair/f3eeef4afb81223df96575adadf808fe7fe440b4)
+ - [The OSIAN Corpus](https://www.aclweb.org/anthology/W19-4619)
+ - Assafir news articles. A huge thank you to Assafir for the data
+
+ ## Models
+
+ Model | HuggingFace Model Name | Size (MB/Params) | Pre-Segmentation | Hardware | Sequence Length | Batch Size | Num of Steps | Total Time (in Days) |
+ ---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+ AraBERTv0.2-base | [bert-base-arabertv02](https://huggingface.co/aubmindlab/bert-base-arabertv02) | 543MB / 136M | No | TPUv3-8 | 128 / 512 | 2560 / 384 | 1M / 2M | 36 |
+ AraBERTv0.2-large | [bert-large-arabertv02](https://huggingface.co/aubmindlab/bert-large-arabertv02) | 1.38G / 371M | No | TPUv3-128 | 128 / 512 | 13440 / 2056 | 250K / 300K | 7 |
+ AraBERTv2-base | [bert-base-arabertv2](https://huggingface.co/aubmindlab/bert-base-arabertv2) | 543MB / 136M | Yes | TPUv3-8 | 128 / 512 | 2560 / 384 | 1M / 2M | 36 |
+ AraBERTv2-large | [bert-large-arabertv2](https://huggingface.co/aubmindlab/bert-large-arabertv2) | 1.38G / 371M | Yes | TPUv3-128 | 128 / 512 | 13440 / 2056 | 250K / 300K | 7 |
+ AraBERTv0.1-base | [bert-base-arabertv01](https://huggingface.co/aubmindlab/bert-base-arabertv01) | 543MB / 136M | No | TPUv2-8 | 128 / 512 | 128 / 512 | 900K / 300K | 4 |
+ AraBERTv1-base | [bert-base-arabert](https://huggingface.co/aubmindlab/bert-base-arabert) | 543MB / 136M | Yes | TPUv2-8 | 128 / 512 | 128 / 512 | 900K / 300K | 4 |
+ AraGPT2-base | [aragpt2-base](https://huggingface.co/aubmindlab/aragpt2-base) | 527MB / 135M | No | TPUv3-128 | 1024 | 1792 | 125K | 1.5 |
+ AraGPT2-medium | [aragpt2-medium](https://huggingface.co/aubmindlab/aragpt2-medium) | 1.38G / 370M | No | TPUv3-8 | 1024 | 80 | 1M | 15 |
+ AraGPT2-large | [aragpt2-large](https://huggingface.co/aubmindlab/aragpt2-large) | 2.98GB / 792M | No | TPUv3-128 | 1024 | 256 | 220K | 3 |
+ AraGPT2-mega | [aragpt2-mega](https://huggingface.co/aubmindlab/aragpt2-mega) | 5.5GB / 1.46B | No | TPUv3-128 | 1024 | 256 | 800K | 9 |
+ AraELECTRA-base-generator | [araelectra-base-generator](https://huggingface.co/aubmindlab/araelectra-base-generator) | 227MB / 60M | No | TPUv3-8 | 512 | 256 | 2M | 24 |
+ AraELECTRA-base-discriminator | [araelectra-base-discriminator](https://huggingface.co/aubmindlab/araelectra-base-discriminator) | 516MB / 135M | No | TPUv3-8 | 512 | 256 | 2M | 24 |
+ AraBERTv0.2-Twitter-base | [bert-base-arabertv02-twitter](https://huggingface.co/aubmindlab/bert-base-arabertv02-twitter) | 543MB / 136M | No | V100 | *64* | - | - | - |
+ AraBERTv0.2-Twitter-large | [bert-large-arabertv02-twitter](https://huggingface.co/aubmindlab/bert-large-arabertv02-twitter) | 1.38G / 371M | No | V100 | *64* | - | - | - |
+
+ All models are available on the `HuggingFace` model page under the [aubmindlab](https://huggingface.co/aubmindlab/) name. Checkpoints are available in PyTorch, TF2 and TF1 formats.
+
+ # Preprocessing
+
+ You can test the Arabic preprocessing pipeline in the Arabic Text Preprocessor page.
+
+ It is recommended to apply our preprocessing function before training/testing on any dataset.
+ **Install farasapy to segment text for AraBERT v1 & v2: `pip install farasapy`**
+
+ ```python
+ from arabert.preprocess import ArabertPreprocessor
+
+ model_name = "aubmindlab/bert-base-arabertv2"
+ arabert_prep = ArabertPreprocessor(model_name=model_name)
+
+ text = "ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
+ arabert_prep.preprocess(text)
+ >>> "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
+ ```
+
+ You can also use the `unpreprocess()` function to reverse the preprocessing changes, by fixing the spacing around non-alphabetical characters and de-segmenting the text if the selected model needs pre-segmentation. We highly recommend unprocessing the generated output of `AraGPT2`, to make it look more natural.
+ ```python
+ output_text = "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
+ arabert_prep.unpreprocess(output_text)
+ >>> "ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
+ ```
+
+ # If you used this model please cite us as:
+
+ ## AraBERT
+ Google Scholar has our BibTeX wrong (missing name), use this instead:
+ ```
+ @inproceedings{antoun2020arabert,
+   title={AraBERT: Transformer-based Model for Arabic Language Understanding},
+   author={Antoun, Wissam and Baly, Fady and Hajj, Hazem},
+   booktitle={LREC 2020 Workshop Language Resources and Evaluation Conference 11--16 May 2020},
+   pages={9}
+ }
+ ```
+ ## AraGPT2
+ ```
+ @inproceedings{antoun-etal-2021-aragpt2,
+     title = "{A}ra{GPT}2: Pre-Trained Transformer for {A}rabic Language Generation",
+     author = "Antoun, Wissam and Baly, Fady and Hajj, Hazem",
+     booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
+     month = apr,
+     year = "2021",
+     address = "Kyiv, Ukraine (Virtual)",
+     publisher = "Association for Computational Linguistics",
+     url = "https://www.aclweb.org/anthology/2021.wanlp-1.21",
+     pages = "196--207",
+ }
+ ```
+
+ ## AraELECTRA
+ ```
+ @inproceedings{antoun-etal-2021-araelectra,
+     title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
+     author = "Antoun, Wissam and Baly, Fady and Hajj, Hazem",
+     booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
+     month = apr,
+     year = "2021",
+     address = "Kyiv, Ukraine (Virtual)",
+     publisher = "Association for Computational Linguistics",
+     url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
+     pages = "191--195",
+ }
+ ```
+
+
+ # Acknowledgments
+ Thanks to TensorFlow Research Cloud (TFRC) for the free access to Cloud TPUs, we couldn't have done it without this program, and to the [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) members for the continuous support. Also thanks to [Yakshof](https://www.yakshof.com/#/) and Assafir for data and storage access. Another thanks to [Habib Rahal](https://www.behance.net/rahalhabib) for putting a face to AraBERT.
+
+ # Contacts
+ **Wissam Antoun**: [Linkedin](https://www.linkedin.com/in/wissam-antoun-622142b4/) | [Twitter](https://twitter.com/wissam_antoun) | [Github](https://github.com/WissamAntoun) | wfa07 (AT) mail (DOT) aub (DOT) edu | wissam.antoun (AT) gmail (DOT) com
+
+ **Fady Baly**: [Linkedin](https://www.linkedin.com/in/fadybaly/) | [Twitter](https://twitter.com/fadybaly) | [Github](https://github.com/fadybaly) | fgb06 (AT) mail (DOT) aub (DOT) edu | baly.fady (AT) gmail (DOT) com
+
+ """
+     )
backend/modeling_gpt2.py ADDED
@@ -0,0 +1,1599 @@
+ # coding=utf-8
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ PyTorch OpenAI GPT-2 model.
+ Adapted from https://github.com/huggingface/transformers/blob/v4.0.1/src/transformers/models/gpt2/modeling_gpt2.py
+ and https://github.com/ghosthamlet/gpt2-ml-torch/blob/master/gpt2_ml_torch/modeling_gpt2.py
+ """
+
+
+ import logging
+ import os
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss, MSELoss
+ from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
+ from transformers.activations import ACT2FN
+ from transformers.file_utils import (
+     ModelOutput,
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     CausalLMOutputWithCrossAttentions,
+     SequenceClassifierOutputWithPast,
+     TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import (
+     Conv1D,
+     PreTrainedModel,
+     SequenceSummary,
+     find_pruneable_heads_and_indices,
+     prune_conv1d_layer,
+ )
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+
+ # The difference from Transformers is the code under _USE_GROVER
+ _USE_GROVER = True
+
+ logger = logging.getLogger(__name__)
+
+ _CONFIG_FOR_DOC = "GPT2Config"
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+     "gpt2",
+     "gpt2-medium",
+     "gpt2-large",
+     "gpt2-xl",
+     "distilgpt2",
+     # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
+ ]
+
+ logger.setLevel(logging.INFO)
+ console = logging.StreamHandler()
+ console.setLevel(logging.INFO)
+ logger.addHandler(console)
+
+ _GPT2_ML_TF_TO_TORCH = {
+     "LayerNorm_embed_norm": "emb_norm",
+     "pos_embed": "wpe.weight",
+     "word_embed": "wte.weight",
+     "layer": "h",
+     # Most importantly, these two layer norms must be placed in the same position as in gpt2-ml,
+     # or the generated output is bad (it just repeats the last token)
+     "LayerNorm_mlp_ln0": "ln_1",
+     "LayerNorm_mlp_ln1": "ln_2",
+     "intermediate": "mlp.c_fc",
+     "output": "mlp.c_proj",
+     "query_layer": "attn.c_attn",
+     "key_layer": "attn.c_attn",
+     "value_layer": "attn.c_attn",
+     "context_projection_layer": "attn.c_proj",
+     "gamma": "weight",