abxhr committed
Commit b71b65d
0 Parent(s)

Duplicate from abxhr/design-project


Co-authored-by: Abshar Mohammed Aslam <abxhr@users.noreply.huggingface.co>

README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Arabic NLP Demo
+ emoji: ⌨
+ colorFrom: purple
+ colorTo: green
+ sdk: streamlit
+ app_file: app.py
+ pinned: true
+ duplicated_from: abxhr/design-project
+ ---
app.py ADDED
@@ -0,0 +1,38 @@
+ import awesome_streamlit as ast
+ import streamlit as st
+
+ from backend.utils import get_current_ram_usage, ga
+
+ import backend.aragpt
+ import backend.home
+ import backend.processor
+ import backend.sa
+ import backend.qa
+
+ st.set_page_config(
+     page_title="TEST", page_icon="📖", initial_sidebar_state="expanded", layout="wide"
+ )
+
+ ga(st.__file__)
+
+ PAGES = {
+     "Home": backend.home,
+     "Arabic Sentiment Analysis": backend.sa,
+     "Arabic Question Answering": backend.qa,
+ }
+
+
+ st.sidebar.title("Navigation")
+ selection = st.sidebar.radio("Pages", list(PAGES.keys()))
+
+ page = PAGES[selection]
+ # with st.spinner(f"Loading {selection} ..."):
+ ast.shared.components.write_page(page)
+
+ st.sidebar.header("Info")
+ st.sidebar.write("Arabic NLP by [**Abshar Mohammed Aslam** (*2019A7PS0233U*)](https://github.com/abxhr)")
+
+ st.sidebar.write("Submitted to *Dr. Sujala D. Shetty*")
+ # if st.sidebar.checkbox("Show RAM usage"):
+ #     ram = get_current_ram_usage()
+ #     st.sidebar.write("Ram usage: {:.2f}/{:.2f} GB".format(ram[0], ram[1]))
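A note on the pattern above: `ast.shared.components.write_page(page)` simply calls the selected module's `write()` function, so every module registered in `PAGES` must expose one. A minimal sketch of that contract (a hypothetical page module, not part of this commit):

```python
# backend/example.py -- hypothetical page module illustrating the write()
# contract that awesome_streamlit's write_page() relies on.
import streamlit as st


def write():
    # Rendered when this page is chosen in the sidebar radio above.
    st.title("Example page")
    st.write("Page content goes here.")
```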
backend/__init__.py ADDED
File without changes
backend/aragpt.py ADDED
@@ -0,0 +1,189 @@
+ import streamlit as st
+ from .services import TextGeneration
+ from tokenizers import Tokenizer
+ from functools import lru_cache
+
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
+ @lru_cache(maxsize=1)
+ def load_text_generator():
+     generator = TextGeneration()
+     generator.load()
+     return generator
+
+
+ generator = load_text_generator()
+
+ qa_prompt = """
+ أجب عن السؤال التالي:
+ """
+ qa_prompt_post = """ الجواب هو """
+ qa_prompt_post_year = """ في سنة: """
+
+
+ def write():
+     st.markdown(
+         """
+ <h1 style="text-align:left;">Arabic Language Generation</h1>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # Sidebar
+
+     # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
+     st.sidebar.subheader("Configurable parameters")
+
+     model_name = st.sidebar.selectbox(
+         "Model Selector",
+         options=[
+             "AraGPT2-Base",
+             # "AraGPT2-Medium",
+             # "Aragpt2-Large",
+             "AraGPT2-Mega",
+         ],
+         index=0,
+     )
+
+     max_new_tokens = st.sidebar.number_input(
+         "Maximum length",
+         min_value=0,
+         max_value=1024,
+         value=100,
+         help="The maximum length of the sequence to be generated.",
+     )
+     temp = st.sidebar.slider(
+         "Temperature",
+         value=1.0,
+         min_value=0.1,
+         max_value=100.0,
+         help="The value used to modulate the next token probabilities.",
+     )
+     top_k = st.sidebar.number_input(
+         "Top k",
+         value=10,
+         help="The number of highest probability vocabulary tokens to keep for top-k filtering.",
+     )
+     top_p = st.sidebar.number_input(
+         "Top p",
+         value=0.95,
+         help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
+     )
+     do_sample = st.sidebar.selectbox(
+         "Sampling?",
+         (True, False),
+         help="Whether or not to use sampling; use greedy decoding otherwise.",
+     )
+     num_beams = st.sidebar.number_input(
+         "Number of beams",
+         min_value=1,
+         max_value=10,
+         value=3,
+         help="The number of beams to use for beam search.",
+     )
+     repetition_penalty = st.sidebar.number_input(
+         "Repetition Penalty",
+         min_value=0.0,
+         value=3.0,
+         step=0.1,
+         help="The parameter for repetition penalty. 1.0 means no penalty.",
+     )
+     no_repeat_ngram_size = st.sidebar.number_input(
+         "No Repeat N-Gram Size",
+         min_value=0,
+         value=3,
+         help="If set to int > 0, all n-grams of that size can only occur once.",
+     )
+
+     st.write("#")
+
+     col = st.columns(2)
+
+     col[0].image("images/AraGPT2.png", width=200)
+
+     st.markdown(
+         """
+ <h3 style="text-align:left;">AraGPT2 is a GPT2 model trained from scratch on 77GB of Arabic text.</h3>
+ <h4 style="text-align:left;"> More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4>
+
+ <p style="text-align:left;"></p>
+ <p style="text-align:left;">Use the generation parameters on the sidebar to adjust generation quality.</p>
+ <p style="text-align:right;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # col[0].write(
+     #     "AraGPT2 is trained from scratch on 77GB of Arabic text. More details in our [repo](https://github.com/aub-mind/arabert/tree/master/aragpt2)."
+     # )
+     # st.write("## Generate Arabic Text")
+
+     st.markdown(
+         """
+ <style>
+ p, div, input, label, textarea{
+     text-align: right;
+ }
+ </style>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     prompt = st.text_area(
+         "Prompt",
+         "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
+     )
+     if st.button("Generate"):
+         with st.spinner("Generating..."):
+             generated_text = generator.generate(
+                 prompt=prompt,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+         st.write(generated_text)
+
+     st.markdown("---")
+     st.subheader("")
+     st.markdown(
+         """
+ <p style="text-align:left;"></p>
+ <h2 style="text-align:left;">Zero-Shot Question Answering</h2>
+
+ <p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended.</p>
+ <p style="text-align:left;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     question = st.text_input(
+         "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
+     )
+     is_date = st.checkbox("Help the model: Is the answer a date?")
+     if st.button("Answer"):
+         prompt2 = qa_prompt + question + qa_prompt_post
+         if is_date:
+             prompt2 += qa_prompt_post_year
+         else:
+             prompt2 += " : "
+         with st.spinner("Thinking..."):
+             answer = generator.generate(
+                 prompt=prompt2,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+         st.write(answer)
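`backend/services.py` is imported above but not included in this commit view, so the interface of `TextGeneration` has to be inferred from how it is called. A minimal sketch of a compatible implementation, assuming it wraps Hugging Face `transformers` and the `aubmindlab` AraGPT2 checkpoints (the model ids and internals are assumptions, not the Space's actual code):

```python
# Hypothetical backend/services.py; only the load()/generate() surface used by
# aragpt.py above is real, everything else here is an assumption.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Assumed mapping from the sidebar labels to Hub checkpoints.
MODEL_IDS = {
    "AraGPT2-Base": "aubmindlab/aragpt2-base",
}


class TextGeneration:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}

    def load(self):
        # Download and cache one (tokenizer, model) pair per sidebar option.
        for name, model_id in MODEL_IDS.items():
            self.tokenizers[name] = GPT2TokenizerFast.from_pretrained(model_id)
            self.models[name] = GPT2LMHeadModel.from_pretrained(model_id)

    def generate(
        self,
        prompt,
        model_name,
        max_new_tokens,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        do_sample,
        num_beams,
        no_repeat_ngram_size,
    ):
        tokenizer = self.tokenizers[model_name]
        model = self.models[model_name]
        inputs = tokenizer(prompt, return_tensors="pt")
        # Forward the UI parameters straight to transformers' generate().
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
            repetition_penalty=float(repetition_penalty),
            do_sample=bool(do_sample),
            num_beams=int(num_beams),
            no_repeat_ngram_size=int(no_repeat_ngram_size),
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

Note that AraGPT2-Mega uses the Grover-style architecture, which is exactly why this commit vendors `backend/modeling_gpt2.py` below; loading Mega would go through that module's `GPT2LMHeadModel` rather than the stock `transformers` class.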
backend/home.py ADDED
@@ -0,0 +1,20 @@
+ import streamlit as st
+ import awesome_streamlit as ast
+
+
+ def write():
+     st.markdown(
+         """
+ # Arabic Natural Language Processing
+
+ Design project for **Arabic Natural Language Processing**, by [**Abshar Mohammed Aslam**](https://github.com/abxhr).
+ """
+     )
+     st.markdown("#")
+
+
+     st.markdown(
+         """
+
+ """
+     )
backend/modeling_gpt2.py ADDED
@@ -0,0 +1,1599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ PyTorch OpenAI GPT-2 model.
19
+ Adapted from https://github.com/huggingface/transformers/blob/v4.0.1/src/transformers/models/gpt2/modeling_gpt2.py
20
+ and https://github.com/ghosthamlet/gpt2-ml-torch/blob/master/gpt2_ml_torch/modeling_gpt2.py
21
+ """
22
+
23
+
24
+ import logging
25
+ import os
26
+ from dataclasses import dataclass
27
+ from typing import List, Optional, Tuple
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from torch.nn import CrossEntropyLoss, MSELoss
32
+ from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
33
+ from transformers.activations import ACT2FN
34
+ from transformers.file_utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ replace_return_docstrings,
40
+ )
41
+ from transformers.modeling_outputs import (
42
+ BaseModelOutputWithPastAndCrossAttentions,
43
+ CausalLMOutputWithCrossAttentions,
44
+ SequenceClassifierOutputWithPast,
45
+ TokenClassifierOutput,
46
+ )
47
+ from transformers.modeling_utils import (
48
+ Conv1D,
49
+ PreTrainedModel,
50
+ SequenceSummary,
51
+ find_pruneable_heads_and_indices,
52
+ prune_conv1d_layer,
53
+ )
54
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
55
+
56
+ # THe Difference from Transformers is code under _USE_GROVER
57
+ _USE_GROVER = True
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ _CONFIG_FOR_DOC = "GPT2Config"
62
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
63
+
64
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
65
+ "gpt2",
66
+ "gpt2-medium",
67
+ "gpt2-large",
68
+ "gpt2-xl",
69
+ "distilgpt2",
70
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
71
+ ]
72
+
73
+ logger.setLevel(logging.INFO)
74
+ console = logging.StreamHandler()
75
+ console.setLevel(logging.INFO)
76
+ logger.addHandler(console)
77
+
78
+ _GPT2_ML_TF_TO_TORCH = {
79
+ "LayerNorm_embed_norm": "emb_norm",
80
+ "pos_embed": "wpe.weight",
81
+ "word_embed": "wte.weight",
82
+ "layer": "h",
83
+ # Most importently This two layer norm must be put on the same position as gpt2-ml
84
+ # or generated data is bad, just repeat the last token
85
+ "LayerNorm_mlp_ln0": "ln_1",
86
+ "LayerNorm_mlp_ln1": "ln_2",
87
+ "intermediate": "mlp.c_fc",
88
+ "output": "mlp.c_proj",
89
+ "query_layer": "attn.c_attn",
90
+ "key_layer": "attn.c_attn",
91
+ "value_layer": "attn.c_attn",
92
+ "context_projection_layer": "attn.c_proj",
93
+ "gamma": "weight",
94
+ "kernel": "weight",
95
+ "beta": "bias",
96
+ "bias": "bias",
97
+ }
98
+
99
+
100
+ def convert_gpt2_checkpoint_to_pytorch(
101
+ gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path
102
+ ):
103
+ # Construct model
104
+ if gpt2_config_file == "":
105
+ config = GPT2Config()
106
+ else:
107
+ config = GPT2Config.from_json_file(gpt2_config_file)
108
+ model = GPT2Model(config)
109
+
110
+ # Load weights from numpy
111
+ load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
112
+
113
+ # Save pytorch-model
114
+ pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
115
+ pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
116
+ print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
117
+ torch.save(model.state_dict(), pytorch_weights_dump_path)
118
+ print("Save configuration file to {}".format(pytorch_config_dump_path))
119
+ with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
120
+ f.write(config.to_json_string())
121
+
122
+
123
+ # XXX: MUST do like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
124
+ # https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
125
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
126
+ """Load tf checkpoints in a pytorch model"""
127
+ try:
128
+ import re
129
+
130
+ import tensorflow as tf
131
+ except ImportError:
132
+ logger.error(
133
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
134
+ "https://www.tensorflow.org/install/ for installation instructions."
135
+ )
136
+ raise
137
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
138
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
139
+ # Load weights from TF model
140
+ init_vars = tf.train.list_variables(tf_path)
141
+ names = []
142
+ arrays = []
143
+ for name, shape in init_vars:
144
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
145
+ array = tf.train.load_variable(tf_path, name)
146
+ names.append(name)
147
+ arrays.append(array.squeeze())
148
+
149
+ import copy
150
+
151
+ orig_model = copy.deepcopy(model)
152
+
153
+ for name, array in zip(names, arrays):
154
+ name = name[6:] # skip "model/"
155
+ name = name.split("/")
156
+ pointer = model
157
+
158
+ attn_layer = ""
159
+ for m_name in name:
160
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
161
+ scope_names = re.split(r"(\d+)", m_name)
162
+ else:
163
+ scope_names = [m_name]
164
+ sname = scope_names[0]
165
+
166
+ if sname == "" or sname == "embeddings":
167
+ continue
168
+ elif sname not in _GPT2_ML_TF_TO_TORCH:
169
+ print("=========================================================")
170
+ logger.info("Skip var name {}".format(scope_names))
171
+ pointer = None
172
+ break
173
+ else:
174
+ tname = _GPT2_ML_TF_TO_TORCH[sname]
175
+ if "." in tname:
176
+ parent, child = tname.split(".")
177
+ pointer = getattr(pointer, parent)
178
+ pointer = getattr(pointer, child)
179
+ else:
180
+ pointer = getattr(pointer, tname)
181
+
182
+ if tname == "attn.c_attn":
183
+ attn_layer = sname
184
+
185
+ if len(scope_names) >= 2:
186
+ num = int(scope_names[1])
187
+ pointer = pointer[num]
188
+
189
+ if pointer is None:
190
+ continue
191
+ if attn_layer == "":
192
+ try:
193
+ assert pointer.shape == array.shape
194
+ except AssertionError as e:
195
+ e.args += (pointer.shape, array.shape)
196
+ raise
197
+ logger.info(
198
+ "Initialize PyTorch weight {}, {}, {}".format(
199
+ name, array.mean(), pointer.mean()
200
+ )
201
+ )
202
+ if attn_layer == "":
203
+ pointer.data = torch.from_numpy(array)
204
+ else:
205
+ shape = pointer.shape
206
+ d = torch.from_numpy(array)
207
+ is_bias = len(shape) == 1
208
+ end = int(shape[0 if is_bias else 1] / 3)
209
+ m = dict(
210
+ query_layer=0,
211
+ key_layer=end,
212
+ value_layer=end * 2,
213
+ )
214
+ start = m[attn_layer]
215
+ end = start + end
216
+ if is_bias:
217
+ pointer.data[start:end] = d
218
+ else:
219
+ pointer.data[:, start:end] = d
220
+ logger.info(
221
+ "Initialize PyTorch weight {}, {}, {}".format(
222
+ name, array.mean(), pointer.mean()
223
+ )
224
+ )
225
+
226
+ for name, params in orig_model.named_parameters():
227
+ for n, p in model.named_parameters():
228
+ if name == n:
229
+ if params.equal(p):
230
+ print("--------------------------")
231
+ print(" %s not changed!" % n)
232
+ return model
233
+
234
+
235
+ class Attention(nn.Module):
236
+ def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
237
+ super().__init__()
238
+
239
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
240
+ # [switch nx => n_state from Block to Attention to keep identical to TF implem]
241
+ assert n_state % config.n_head == 0
242
+ self.register_buffer(
243
+ "bias",
244
+ torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(
245
+ 1, 1, n_ctx, n_ctx
246
+ ),
247
+ )
248
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
249
+ self.n_head = config.n_head
250
+ self.split_size = n_state
251
+ self.scale = scale
252
+ self.is_cross_attention = is_cross_attention
253
+ if self.is_cross_attention:
254
+ self.c_attn = Conv1D(2 * n_state, nx)
255
+ self.q_attn = Conv1D(n_state, nx)
256
+ else:
257
+ self.c_attn = Conv1D(3 * n_state, nx)
258
+ self.c_proj = Conv1D(n_state, nx)
259
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
260
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
261
+ self.pruned_heads = set()
262
+
263
+ def prune_heads(self, heads):
264
+ if len(heads) == 0:
265
+ return
266
+ heads, index = find_pruneable_heads_and_indices(
267
+ heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
268
+ )
269
+ index_attn = torch.cat(
270
+ [index, index + self.split_size, index + (2 * self.split_size)]
271
+ )
272
+
273
+ # Prune conv1d layers
274
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
275
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
276
+
277
+ # Update hyper params
278
+ self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
279
+ self.n_head = self.n_head - len(heads)
280
+ self.pruned_heads = self.pruned_heads.union(heads)
281
+
282
+ def _attn(
283
+ self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False
284
+ ):
285
+ w = torch.matmul(q, k)
286
+ if self.scale:
287
+ w = w / (float(v.size(-1)) ** 0.5)
288
+ nd, ns = w.size(-2), w.size(-1)
289
+
290
+ if not self.is_cross_attention:
291
+ # if only "normal" attention layer implements causal mask
292
+ mask = self.bias[:, :, ns - nd : ns, :ns]
293
+ w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
294
+
295
+ if attention_mask is not None:
296
+ # Apply the attention mask
297
+ w = w + attention_mask
298
+
299
+ w = nn.Softmax(dim=-1)(w)
300
+ w = self.attn_dropout(w)
301
+
302
+ # Mask heads if we want to
303
+ if head_mask is not None:
304
+ w = w * head_mask
305
+
306
+ outputs = [torch.matmul(w, v)]
307
+ if output_attentions:
308
+ outputs.append(w)
309
+ return outputs
310
+
311
+ def merge_heads(self, x):
312
+ x = x.permute(0, 2, 1, 3).contiguous()
313
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
314
+ return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
315
+
316
+ def split_heads(self, x, k=False):
317
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
318
+ x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
319
+ if k:
320
+ return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
321
+ else:
322
+ return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
323
+
324
+ def forward(
325
+ self,
326
+ hidden_states,
327
+ layer_past=None,
328
+ attention_mask=None,
329
+ head_mask=None,
330
+ encoder_hidden_states=None,
331
+ encoder_attention_mask=None,
332
+ use_cache=False,
333
+ output_attentions=False,
334
+ ):
335
+ if encoder_hidden_states is not None:
336
+ assert hasattr(
337
+ self, "q_attn"
338
+ ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
339
+ query = self.q_attn(hidden_states)
340
+ key, value = self.c_attn(encoder_hidden_states).split(
341
+ self.split_size, dim=2
342
+ )
343
+ attention_mask = encoder_attention_mask
344
+ else:
345
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
346
+
347
+ query = self.split_heads(query)
348
+ key = self.split_heads(key, k=True)
349
+ value = self.split_heads(value)
350
+ if layer_past is not None:
351
+ past_key, past_value = (
352
+ layer_past[0].transpose(-2, -1),
353
+ layer_past[1],
354
+ ) # transpose back cf below
355
+ key = torch.cat((past_key, key), dim=-1)
356
+ value = torch.cat((past_value, value), dim=-2)
357
+
358
+ if use_cache is True:
359
+ present = torch.stack(
360
+ (key.transpose(-2, -1), value)
361
+ ) # transpose to have same shapes for stacking
362
+ else:
363
+ present = (None,)
364
+
365
+ attn_outputs = self._attn(
366
+ query, key, value, attention_mask, head_mask, output_attentions
367
+ )
368
+ a = attn_outputs[0]
369
+
370
+ a = self.merge_heads(a)
371
+ a = self.c_proj(a)
372
+ a = self.resid_dropout(a)
373
+
374
+ outputs = [a, present] + attn_outputs[1:]
375
+ return outputs # a, present, (attentions)
376
+
377
+
378
+ class MLP(nn.Module):
379
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
380
+ super().__init__()
381
+ nx = config.n_embd
382
+ self.c_fc = Conv1D(n_state, nx)
383
+ self.c_proj = Conv1D(nx, n_state)
384
+ self.act = ACT2FN[config.activation_function]
385
+ self.dropout = nn.Dropout(config.resid_pdrop)
386
+
387
+ def forward(self, x):
388
+ h = self.act(self.c_fc(x))
389
+ h2 = self.c_proj(h)
390
+ return self.dropout(h2)
391
+
392
+
393
+ class Block(nn.Module):
394
+ def __init__(self, n_ctx, config, scale=False):
395
+ super().__init__()
396
+ hidden_size = config.n_embd
397
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
398
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
399
+ self.attn = Attention(hidden_size, n_ctx, config, scale)
400
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
401
+ if config.add_cross_attention:
402
+ self.crossattention = Attention(
403
+ hidden_size, n_ctx, config, scale, is_cross_attention=True
404
+ )
405
+ self.ln_cross_attn = nn.LayerNorm(
406
+ hidden_size, eps=config.layer_norm_epsilon
407
+ )
408
+ self.mlp = MLP(inner_dim, config)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states,
413
+ layer_past=None,
414
+ attention_mask=None,
415
+ head_mask=None,
416
+ encoder_hidden_states=None,
417
+ encoder_attention_mask=None,
418
+ use_cache=False,
419
+ output_attentions=False,
420
+ ):
421
+ attn_outputs = self.attn(
422
+ hidden_states,
423
+ layer_past=layer_past,
424
+ attention_mask=attention_mask,
425
+ head_mask=head_mask,
426
+ use_cache=use_cache,
427
+ output_attentions=output_attentions,
428
+ )
429
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
430
+ outputs = attn_outputs[1:]
431
+ # residual connection
432
+ hidden_states = attn_output + hidden_states
433
+
434
+ if encoder_hidden_states is not None:
435
+ # add one self-attention block for cross-attention
436
+ assert hasattr(
437
+ self, "crossattention"
438
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
439
+ cross_attn_outputs = self.crossattention(
440
+ self.ln_cross_attn(hidden_states),
441
+ attention_mask=attention_mask,
442
+ head_mask=head_mask,
443
+ encoder_hidden_states=encoder_hidden_states,
444
+ encoder_attention_mask=encoder_attention_mask,
445
+ output_attentions=output_attentions,
446
+ )
447
+ attn_output = cross_attn_outputs[0]
448
+ # residual connection
449
+ hidden_states = hidden_states + attn_output
450
+ outputs = (
451
+ outputs + cross_attn_outputs[2:]
452
+ ) # add cross attentions if we output attention weights
453
+
454
+ feed_forward_hidden_states = self.mlp(self.ln_1(hidden_states))
455
+ # residual connection
456
+ hidden_states = hidden_states + feed_forward_hidden_states
457
+
458
+ hidden_states = self.ln_2(hidden_states)
459
+
460
+ outputs = [hidden_states] + outputs
461
+ return outputs # hidden_states, present, (attentions, cross_attentions)
462
+
463
+
464
+ class GPT2PreTrainedModel(PreTrainedModel):
465
+ """
466
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
467
+ models.
468
+ """
469
+
470
+ config_class = GPT2Config
471
+ load_tf_weights = load_tf_weights_in_gpt2
472
+ base_model_prefix = "transformer"
473
+ is_parallelizable = True
474
+
475
+ def __init__(self, *inputs, **kwargs):
476
+ super().__init__(*inputs, **kwargs)
477
+
478
+ def _init_weights(self, module):
479
+ """Initialize the weights."""
480
+ if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
481
+ # Slightly different from the TF version which uses truncated_normal for initialization
482
+ # cf https://github.com/pytorch/pytorch/pull/5617
483
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
484
+ if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
485
+ module.bias.data.zero_()
486
+ elif isinstance(module, nn.LayerNorm):
487
+ module.bias.data.zero_()
488
+ module.weight.data.fill_(1.0)
489
+
490
+
491
+ @dataclass
492
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
493
+ """
494
+ Base class for outputs of models predicting if two sentences are consecutive or not.
495
+
496
+ Args:
497
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
498
+ Language modeling loss.
499
+ mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
500
+ Multiple choice classification loss.
501
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
502
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
503
+ mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
504
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
505
+ past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
506
+ List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
507
+ batch_size, num_heads, sequence_length, embed_size_per_head)`).
508
+
509
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
510
+ :obj:`past_key_values` input) to speed up sequential decoding.
511
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
512
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
513
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
514
+
515
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
516
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
517
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
518
+ sequence_length, sequence_length)`.
519
+
520
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
521
+ heads.
522
+ """
523
+
524
+ loss: Optional[torch.FloatTensor] = None
525
+ mc_loss: Optional[torch.FloatTensor] = None
526
+ logits: torch.FloatTensor = None
527
+ mc_logits: torch.FloatTensor = None
528
+ past_key_values: Optional[List[torch.FloatTensor]] = None
529
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
530
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
531
+
532
+
533
+ GPT2_START_DOCSTRING = r"""
534
+
535
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
536
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
537
+ pruning heads etc.)
538
+
539
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
540
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
541
+ general usage and behavior.
542
+
543
+ Parameters:
544
+ config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
545
+ Initializing with a config file does not load the weights associated with the model, only the
546
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
547
+ weights.
548
+ """
549
+
550
+ GPT2_INPUTS_DOCSTRING = r"""
551
+ Args:
552
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
553
+ :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
554
+ ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
555
+ sequence tokens in the vocabulary.
556
+
557
+ If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
558
+ passed as ``input_ids``.
559
+
560
+ Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
561
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
562
+ details.
563
+
564
+ `What are input IDs? <../glossary.html#input-ids>`__
565
+ past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
566
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
567
+ :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
568
+ have their past given to this model should not be passed as ``input_ids`` as they have already been
569
+ computed.
570
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
571
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
572
+
573
+ - 1 for tokens that are **not masked**,
574
+ - 0 for tokens that are **masked**.
575
+
576
+ `What are attention masks? <../glossary.html#attention-mask>`__
577
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
578
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
579
+ 1]``:
580
+
581
+ - 0 corresponds to a `sentence A` token,
582
+ - 1 corresponds to a `sentence B` token.
583
+
584
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
585
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
586
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
587
+ config.max_position_embeddings - 1]``.
588
+
589
+ `What are position IDs? <../glossary.html#position-ids>`_
590
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
591
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
592
+
593
+ - 1 indicates the head is **not masked**,
594
+ - 0 indicates the head is **masked**.
595
+
596
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
597
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
598
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
599
+ vectors than the model's internal embedding lookup matrix.
600
+
601
+ If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
602
+ :obj:`past_key_values`).
603
+ use_cache (:obj:`bool`, `optional`):
604
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
605
+ decoding (see :obj:`past_key_values`).
606
+ output_attentions (:obj:`bool`, `optional`):
607
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
608
+ tensors for more detail.
609
+ output_hidden_states (:obj:`bool`, `optional`):
610
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
611
+ more detail.
612
+ return_dict (:obj:`bool`, `optional`):
613
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
614
+ """
615
+
616
+ PARALLELIZE_DOCSTRING = r"""
617
+ This is an experimental feature and is a subject to change at a moment's notice.
618
+
619
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
620
+ it will evenly distribute blocks across all devices.
621
+
622
+ Args:
623
+ device_map (:obj:`Dict[int, list]`, optional, defaults to None):
624
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
625
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
626
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
627
+ following number of attention modules:
628
+
629
+ - gpt2: 12
630
+ - gpt2-medium: 24
631
+ - gpt2-large: 36
632
+ - gpt2-xl: 48
633
+
634
+ Example::
635
+
636
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
637
+ model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
638
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
639
+
640
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
641
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
642
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
643
+ model.parallelize(device_map)
644
+ """
645
+ DEPARALLELIZE_DOCSTRING = r"""
646
+ Moves the model to cpu from a model parallel state.
647
+
648
+ Example::
649
+
650
+ # On a 4 GPU machine with gpt2-large:
651
+ model = GPT2LMHeadModel.from_pretrained('gpt2-large')
652
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
653
+
654
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
655
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
656
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
657
+ model.parallelize(device_map) # Splits the model across several devices
658
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
659
+ """
660
+
661
+
662
+ @add_start_docstrings(
663
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
664
+ GPT2_START_DOCSTRING,
665
+ )
666
+ class GPT2Model(GPT2PreTrainedModel):
667
+ def __init__(self, config):
668
+ super().__init__(config)
669
+
670
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
671
+ self.wpe = nn.Embedding(config.n_positions, config.n_embd)
672
+ if _USE_GROVER:
673
+ self.emb_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
674
+
675
+ self.drop = nn.Dropout(config.embd_pdrop)
676
+ self.h = nn.ModuleList(
677
+ [Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]
678
+ )
679
+ if not _USE_GROVER:
680
+ self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
681
+
682
+ self.init_weights()
683
+
684
+ # Model parallel
685
+ self.model_parallel = False
686
+ self.device_map = None
687
+
688
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
689
+ def parallelize(self, device_map=None):
690
+ # Check validity of device_map
691
+ self.device_map = (
692
+ get_device_map(len(self.h), range(torch.cuda.device_count()))
693
+ if device_map is None
694
+ else device_map
695
+ )
696
+ assert_device_map(self.device_map, len(self.h))
697
+ self.model_parallel = True
698
+ self.first_device = (
699
+ "cpu"
700
+ if "cpu" in self.device_map.keys()
701
+ else "cuda:" + str(min(self.device_map.keys()))
702
+ )
703
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
704
+ self.wte = self.wte.to(self.first_device)
705
+ self.wpe = self.wpe.to(self.first_device)
706
+ # Load onto devices
707
+ for k, v in self.device_map.items():
708
+ for block in v:
709
+ cuda_device = "cuda:" + str(k)
710
+ self.h[block] = self.h[block].to(cuda_device)
711
+ # ln_f to last
712
+ self.ln_f = self.ln_f.to(self.last_device)
713
+
714
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
715
+ def deparallelize(self):
716
+ self.model_parallel = False
717
+ self.device_map = None
718
+ self.first_device = "cpu"
719
+ self.last_device = "cpu"
720
+ self.wte = self.wte.to("cpu")
721
+ self.wpe = self.wpe.to("cpu")
722
+ for index in range(len(self.h)):
723
+ self.h[index] = self.h[index].to("cpu")
724
+ self.ln_f = self.ln_f.to("cpu")
725
+ torch.cuda.empty_cache()
726
+
727
+ def get_input_embeddings(self):
728
+ return self.wte
729
+
730
+ def set_input_embeddings(self, new_embeddings):
731
+ self.wte = new_embeddings
732
+
733
+ def _prune_heads(self, heads_to_prune):
734
+ """
735
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
736
+ """
737
+ for layer, heads in heads_to_prune.items():
738
+ self.h[layer].attn.prune_heads(heads)
739
+
740
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
741
+ @add_code_sample_docstrings(
742
+ tokenizer_class=_TOKENIZER_FOR_DOC,
743
+ checkpoint="gpt2",
744
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
745
+ config_class=_CONFIG_FOR_DOC,
746
+ )
747
+ def forward(
748
+ self,
749
+ input_ids=None,
750
+ past_key_values=None,
751
+ attention_mask=None,
752
+ token_type_ids=None,
753
+ position_ids=None,
754
+ head_mask=None,
755
+ inputs_embeds=None,
756
+ encoder_hidden_states=None,
757
+ encoder_attention_mask=None,
758
+ use_cache=None,
759
+ output_attentions=None,
760
+ output_hidden_states=None,
761
+ return_dict=None,
762
+ ):
763
+ output_attentions = (
764
+ output_attentions
765
+ if output_attentions is not None
766
+ else self.config.output_attentions
767
+ )
768
+ output_hidden_states = (
769
+ output_hidden_states
770
+ if output_hidden_states is not None
771
+ else self.config.output_hidden_states
772
+ )
773
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
774
+ return_dict = (
775
+ return_dict if return_dict is not None else self.config.use_return_dict
776
+ )
777
+
778
+ if input_ids is not None and inputs_embeds is not None:
779
+ raise ValueError(
780
+ "You cannot specify both input_ids and inputs_embeds at the same time"
781
+ )
782
+ elif input_ids is not None:
783
+ input_shape = input_ids.size()
784
+ input_ids = input_ids.view(-1, input_shape[-1])
785
+ batch_size = input_ids.shape[0]
786
+ elif inputs_embeds is not None:
787
+ input_shape = inputs_embeds.size()[:-1]
788
+ batch_size = inputs_embeds.shape[0]
789
+ else:
790
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
791
+
792
+ if token_type_ids is not None:
793
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
794
+ if position_ids is not None:
795
+ position_ids = position_ids.view(-1, input_shape[-1])
796
+
797
+ if past_key_values is None:
798
+ past_length = 0
799
+ past_key_values = [None] * len(self.h)
800
+ else:
801
+ past_length = past_key_values[0][0].size(-2)
802
+ if position_ids is None:
803
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
804
+ position_ids = torch.arange(
805
+ past_length,
806
+ input_shape[-1] + past_length,
807
+ dtype=torch.long,
808
+ device=device,
809
+ )
810
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
811
+
812
+ # Attention mask.
813
+ if attention_mask is not None:
814
+ if batch_size <= 0:
815
+ raise ValueError("batch_size has to be defined and > 0")
816
+ attention_mask = attention_mask.view(batch_size, -1)
817
+ # We create a 3D attention mask from a 2D tensor mask.
818
+ # Sizes are [batch_size, 1, 1, to_seq_length]
819
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
820
+ # this attention mask is more simple than the triangular masking of causal attention
821
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
822
+ attention_mask = attention_mask[:, None, None, :]
823
+
824
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
825
+ # masked positions, this operation will create a tensor which is 0.0 for
826
+ # positions we want to attend and -10000.0 for masked positions.
827
+ # Since we are adding it to the raw scores before the softmax, this is
828
+ # effectively the same as removing these entirely.
829
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
830
+ attention_mask = (1.0 - attention_mask) * -10000.0
831
+
832
+ # If a 2D ou 3D attention mask is provided for the cross-attention
833
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
834
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
835
+ (
836
+ encoder_batch_size,
837
+ encoder_sequence_length,
838
+ _,
839
+ ) = encoder_hidden_states.size()
840
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
841
+ if encoder_attention_mask is None:
842
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
843
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
844
+ else:
845
+ encoder_attention_mask = None
846
+
847
+ # Prepare head mask if needed
848
+ # 1.0 in head_mask indicate we keep the head
849
+ # attention_probs has shape bsz x n_heads x N x N
850
+ # head_mask has shape n_layer x batch x n_heads x N x N
851
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
852
+
853
+ if inputs_embeds is None:
854
+ inputs_embeds = self.wte(input_ids)
855
+ position_embeds = self.wpe(position_ids)
856
+ hidden_states = inputs_embeds + position_embeds
857
+
858
+ if token_type_ids is not None:
859
+ token_type_embeds = self.wte(token_type_ids)
860
+ hidden_states = hidden_states + token_type_embeds
861
+
862
+ hidden_states = self.drop(hidden_states)
863
+ if _USE_GROVER:
864
+ hidden_states = self.emb_norm(hidden_states)
865
+ output_shape = input_shape + (hidden_states.size(-1),)
866
+
867
+ presents = () if use_cache else None
868
+ all_self_attentions = () if output_attentions else None
869
+ all_cross_attentions = (
870
+ () if output_attentions and self.config.add_cross_attention else None
871
+ )
872
+ all_hidden_states = () if output_hidden_states else None
873
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
874
+
875
+ # Model parallel
876
+ if self.model_parallel:
877
+ torch.cuda.set_device(hidden_states.device)
878
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
879
+ if layer_past is not None:
880
+ layer_past = tuple(
881
+ past_state.to(hidden_states.device) for past_state in layer_past
882
+ )
883
+ # Ensure that attention_mask is always on the same device as hidden_states
884
+ if attention_mask is not None:
885
+ attention_mask = attention_mask.to(hidden_states.device)
886
+ if isinstance(head_mask, torch.Tensor):
887
+ head_mask = head_mask.to(hidden_states.device)
888
+
889
+ if output_hidden_states:
890
+ all_hidden_states = all_hidden_states + (
891
+ hidden_states.view(*output_shape),
892
+ )
893
+
894
+ if getattr(self.config, "gradient_checkpointing", False):
895
+
896
+ def create_custom_forward(module):
897
+ def custom_forward(*inputs):
898
+ # checkpointing only works with tuple returns, not with lists
899
+ return tuple(
900
+ output
901
+ for output in module(*inputs, use_cache, output_attentions)
902
+ )
903
+
904
+ return custom_forward
905
+
906
+ outputs = torch.utils.checkpoint.checkpoint(
907
+ create_custom_forward(block),
908
+ hidden_states,
909
+ layer_past,
910
+ attention_mask,
911
+ head_mask[i],
912
+ encoder_hidden_states,
913
+ encoder_attention_mask,
914
+ )
915
+ else:
916
+ outputs = block(
917
+ hidden_states,
918
+ layer_past=layer_past,
919
+ attention_mask=attention_mask,
920
+ head_mask=head_mask[i],
921
+ encoder_hidden_states=encoder_hidden_states,
922
+ encoder_attention_mask=encoder_attention_mask,
923
+ use_cache=use_cache,
924
+ output_attentions=output_attentions,
925
+ )
926
+
927
+ hidden_states, present = outputs[:2]
928
+ if use_cache is True:
929
+ presents = presents + (present,)
930
+
931
+ if output_attentions:
932
+ all_self_attentions = all_self_attentions + (
933
+ outputs[2 if use_cache else 1],
934
+ )
935
+ if self.config.add_cross_attention:
936
+ all_cross_attentions = all_cross_attentions + (
937
+ outputs[3 if use_cache else 2],
938
+ )
939
+
940
+ # Model Parallel: If it's the last layer for that device, put things on the next device
941
+ if self.model_parallel:
942
+ for k, v in self.device_map.items():
943
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
944
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
945
+
946
+ if not _USE_GROVER:
947
+ hidden_states = self.ln_f(hidden_states)
948
+
949
+ hidden_states = hidden_states.view(*output_shape)
950
+ # Add last hidden state
951
+ if output_hidden_states:
952
+ all_hidden_states = all_hidden_states + (hidden_states,)
953
+
954
+ if not return_dict:
955
+ return tuple(
956
+ v
957
+ for v in [
958
+ hidden_states,
959
+ presents,
960
+ all_hidden_states,
961
+ all_self_attentions,
962
+ all_cross_attentions,
963
+ ]
964
+ if v is not None
965
+ )
966
+
967
+ return BaseModelOutputWithPastAndCrossAttentions(
968
+ last_hidden_state=hidden_states,
969
+ past_key_values=presents,
970
+ hidden_states=all_hidden_states,
971
+ attentions=all_self_attentions,
972
+ cross_attentions=all_cross_attentions,
973
+ )
974
+
975
+
976
+ @add_start_docstrings(
977
+ """
978
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
979
+ embeddings).
980
+ """,
981
+ GPT2_START_DOCSTRING,
982
+ )
983
+ class GPT2LMHeadModel(GPT2PreTrainedModel):
984
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
985
+
986
+ def __init__(self, config):
987
+ super().__init__(config)
988
+ self.transformer = GPT2Model(config)
989
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
990
+
991
+ self.init_weights()
992
+
993
+ # Model parallel
994
+ self.model_parallel = False
995
+ self.device_map = None
996
+
997
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
998
+ def parallelize(self, device_map=None):
999
+ self.device_map = (
1000
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1001
+ if device_map is None
1002
+ else device_map
1003
+ )
1004
+ assert_device_map(self.device_map, len(self.transformer.h))
1005
+ self.transformer.parallelize(self.device_map)
1006
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1007
+ self.model_parallel = True
1008
+
1009
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1010
+ def deparallelize(self):
1011
+ self.transformer.deparallelize()
1012
+ self.transformer = self.transformer.to("cpu")
1013
+ self.lm_head = self.lm_head.to("cpu")
1014
+ self.model_parallel = False
1015
+ torch.cuda.empty_cache()
1016
+
1017
+ def get_output_embeddings(self):
1018
+ return self.lm_head
1019
+
1020
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1021
+ token_type_ids = kwargs.get("token_type_ids", None)
1022
+ # only last token for inputs_ids if past is defined in kwargs
1023
+ if past:
1024
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1025
+ if token_type_ids is not None:
1026
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1027
+
1028
+ attention_mask = kwargs.get("attention_mask", None)
1029
+ position_ids = kwargs.get("position_ids", None)
1030
+
1031
+ if attention_mask is not None and position_ids is None:
1032
+ # create position_ids on the fly for batch generation
1033
+ position_ids = attention_mask.long().cumsum(-1) - 1
1034
+ position_ids.masked_fill_(attention_mask == 0, 1)
1035
+ if past:
1036
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1037
+ else:
1038
+ position_ids = None
1039
+ return {
1040
+ "input_ids": input_ids,
1041
+ "past_key_values": past,
1042
+ "use_cache": kwargs.get("use_cache"),
1043
+ "position_ids": position_ids,
1044
+ "attention_mask": attention_mask,
1045
+ "token_type_ids": token_type_ids,
1046
+ }
1047
+
1048
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1049
+ @add_code_sample_docstrings(
1050
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1051
+ checkpoint="gpt2",
1052
+ output_type=CausalLMOutputWithCrossAttentions,
1053
+ config_class=_CONFIG_FOR_DOC,
1054
+ )
1055
+ def forward(
1056
+ self,
1057
+ input_ids=None,
1058
+ past_key_values=None,
1059
+ attention_mask=None,
1060
+ token_type_ids=None,
1061
+ position_ids=None,
1062
+ head_mask=None,
1063
+ inputs_embeds=None,
1064
+ encoder_hidden_states=None,
1065
+ encoder_attention_mask=None,
1066
+ labels=None,
1067
+ use_cache=None,
1068
+ output_attentions=None,
1069
+ output_hidden_states=None,
1070
+ return_dict=None,
1071
+ ):
1072
+ r"""
1073
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1074
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1075
+ ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
1076
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
1077
+ """
1078
+ return_dict = (
1079
+ return_dict if return_dict is not None else self.config.use_return_dict
1080
+ )
1081
+
1082
+ transformer_outputs = self.transformer(
1083
+ input_ids,
1084
+ past_key_values=past_key_values,
1085
+ attention_mask=attention_mask,
1086
+ token_type_ids=token_type_ids,
1087
+ position_ids=position_ids,
1088
+ head_mask=head_mask,
1089
+ inputs_embeds=inputs_embeds,
1090
+ encoder_hidden_states=encoder_hidden_states,
1091
+ encoder_attention_mask=encoder_attention_mask,
1092
+ use_cache=use_cache,
1093
+ output_attentions=output_attentions,
1094
+ output_hidden_states=output_hidden_states,
1095
+ return_dict=return_dict,
1096
+ )
1097
+ hidden_states = transformer_outputs[0]
1098
+
1099
+ # Set device for model parallelism
1100
+ if self.model_parallel:
1101
+ torch.cuda.set_device(self.transformer.first_device)
1102
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1103
+
1104
+ lm_logits = self.lm_head(hidden_states)
1105
+
1106
+ loss = None
1107
+ if labels is not None:
1108
+ # Shift so that tokens < n predict n
1109
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1110
+ shift_labels = labels[..., 1:].contiguous()
1111
+ # Flatten the tokens
1112
+ loss_fct = CrossEntropyLoss()
1113
+ loss = loss_fct(
1114
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1115
+ )
1116
+
1117
+ if not return_dict:
1118
+ output = (lm_logits,) + transformer_outputs[1:]
1119
+ return ((loss,) + output) if loss is not None else output
1120
+
1121
+ return CausalLMOutputWithCrossAttentions(
1122
+ loss=loss,
1123
+ logits=lm_logits,
1124
+ past_key_values=transformer_outputs.past_key_values,
1125
+ hidden_states=transformer_outputs.hidden_states,
1126
+ attentions=transformer_outputs.attentions,
1127
+ cross_attentions=transformer_outputs.cross_attentions,
1128
+ )
1129
+
1130
+ @staticmethod
1131
+ def _reorder_cache(
1132
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1133
+ ) -> Tuple[Tuple[torch.Tensor]]:
1134
+ """
1135
+ This function is used to re-order the :obj:`past_key_values` cache if
1136
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1137
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1138
+ """
1139
+ return tuple(
1140
+ tuple(
1141
+ past_state.index_select(0, beam_idx.to(past_state.device))
1142
+ for past_state in layer_past
1143
+ )
1144
+ for layer_past in past
1145
+ )
1146
+
1147
+
1148
+ @add_start_docstrings(
1149
+ """
1150
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1151
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1152
+ input embeddings, the classification head takes as input the input of a specified classification token index in the
1153
+ input sequence).
1154
+ """,
1155
+ GPT2_START_DOCSTRING,
1156
+ )
1157
+ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1158
+ def __init__(self, config):
1159
+ super().__init__(config)
1160
+ config.num_labels = 1
1161
+ self.transformer = GPT2Model(config)
1162
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1163
+ self.multiple_choice_head = SequenceSummary(config)
1164
+
1165
+ self.init_weights()
1166
+
1167
+ # Model parallel
1168
+ self.model_parallel = False
1169
+ self.device_map = None
1170
+
1171
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1172
+ def parallelize(self, device_map=None):
1173
+ self.device_map = (
1174
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1175
+ if device_map is None
1176
+ else device_map
1177
+ )
1178
+ assert_device_map(self.device_map, len(self.transformer.h))
1179
+ self.transformer.parallelize(self.device_map)
1180
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1181
+ self.multiple_choice_head = self.multiple_choice_head.to(
1182
+ self.transformer.first_device
1183
+ )
1184
+ self.model_parallel = True
1185
+
1186
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1187
+ def deparallelize(self):
1188
+ self.transformer.deparallelize()
1189
+ self.transformer = self.transformer.to("cpu")
1190
+ self.lm_head = self.lm_head.to("cpu")
1191
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1192
+ self.model_parallel = False
1193
+ torch.cuda.empty_cache()
1194
+
1195
+ def get_output_embeddings(self):
1196
+ return self.lm_head
1197
+
1198
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1199
+ token_type_ids = kwargs.get("token_type_ids", None)
1200
+ # only last token for inputs_ids if past is defined in kwargs
1201
+ if past:
1202
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1203
+ if token_type_ids is not None:
1204
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1205
+
1206
+ attention_mask = kwargs.get("attention_mask", None)
1207
+ position_ids = kwargs.get("position_ids", None)
1208
+
1209
+ if attention_mask is not None and position_ids is None:
1210
+ # create position_ids on the fly for batch generation
1211
+ position_ids = attention_mask.long().cumsum(-1) - 1
1212
+ position_ids.masked_fill_(attention_mask == 0, 1)
1213
+ if past:
1214
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1215
+ else:
1216
+ position_ids = None
1217
+
1218
+ return {
1219
+ "input_ids": input_ids,
1220
+ "past_key_values": past,
1221
+ "use_cache": kwargs.get("use_cache"),
1222
+ "position_ids": position_ids,
1223
+ "attention_mask": attention_mask,
1224
+ "token_type_ids": token_type_ids,
1225
+ }
1226
+
1227
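When only an attention mask is supplied, the method above derives position ids on the fly so that left padding does not shift the positions of real tokens. A small sketch of that computation (the mask values are illustrative):

    import torch

    # Two sequences, left-padded; 0 marks padding.
    attention_mask = torch.tensor([[0, 0, 1, 1],
                                   [1, 1, 1, 1]])

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # -> tensor([[1, 1, 0, 1],
    #            [0, 1, 2, 3]])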
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1228
+ @replace_return_docstrings(
1229
+ output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
1230
+ )
1231
+ def forward(
1232
+ self,
1233
+ input_ids=None,
1234
+ past_key_values=None,
1235
+ attention_mask=None,
1236
+ token_type_ids=None,
1237
+ position_ids=None,
1238
+ head_mask=None,
1239
+ inputs_embeds=None,
1240
+ mc_token_ids=None,
1241
+ labels=None,
1242
+ mc_labels=None,
1243
+ use_cache=None,
1244
+ output_attentions=None,
1245
+ output_hidden_states=None,
1246
+ return_dict=None,
1247
+ **kwargs,
1248
+ ):
1249
+ r"""
1250
+ mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input):
1251
+ Index of the classification token in each input sequence. Selected in the range ``[0, input_ids.size(-1) -
1252
+ 1]``.
1253
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1254
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1255
+ ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
1256
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
1257
+ mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`):
1258
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1259
+ num_choices - 1]`` where `num_choices` is the size of the second dimension of the input tensors. (see
1260
+ `input_ids` above)
1261
+
1262
+ Return:
1263
+
1264
+ Example::
1265
+
1266
+ >>> import torch
1267
+ >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
1268
+
1269
+ >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
1270
+ >>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
1271
+
1272
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1273
+ >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})
1274
+
1275
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
1276
+
1277
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1278
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1279
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1280
+
1281
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1282
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1283
+
1284
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1285
+ >>> lm_logits = outputs.logits
1286
+ >>> mc_logits = outputs.mc_logits
1287
+
1288
+ """
1289
+ return_dict = (
1290
+ return_dict if return_dict is not None else self.config.use_return_dict
1291
+ )
1292
+
1293
+ transformer_outputs = self.transformer(
1294
+ input_ids,
1295
+ past_key_values=past_key_values,
1296
+ attention_mask=attention_mask,
1297
+ token_type_ids=token_type_ids,
1298
+ position_ids=position_ids,
1299
+ head_mask=head_mask,
1300
+ inputs_embeds=inputs_embeds,
1301
+ use_cache=use_cache,
1302
+ output_attentions=output_attentions,
1303
+ output_hidden_states=output_hidden_states,
1304
+ return_dict=return_dict,
1305
+ )
1306
+
1307
+ hidden_states = transformer_outputs[0]
1308
+
1309
+ # Set device for model parallelism
1310
+ if self.model_parallel:
1311
+ torch.cuda.set_device(self.transformer.first_device)
1312
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1313
+
1314
+ lm_logits = self.lm_head(hidden_states)
1315
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1316
+
1317
+ mc_loss = None
1318
+ if mc_labels is not None:
1319
+ loss_fct = CrossEntropyLoss()
1320
+ mc_loss = loss_fct(
1321
+ mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
1322
+ )
1323
+ lm_loss = None
1324
+ if labels is not None:
1325
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1326
+ shift_labels = labels[..., 1:].contiguous()
1327
+ loss_fct = CrossEntropyLoss()
1328
+ lm_loss = loss_fct(
1329
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1330
+ )
1331
+
1332
+ if not return_dict:
1333
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
1334
+ if mc_loss is not None:
1335
+ output = (mc_loss,) + output
1336
+ return ((lm_loss,) + output) if lm_loss is not None else output
1337
+
1338
+ return GPT2DoubleHeadsModelOutput(
1339
+ loss=lm_loss,
1340
+ mc_loss=mc_loss,
1341
+ logits=lm_logits,
1342
+ mc_logits=mc_logits,
1343
+ past_key_values=transformer_outputs.past_key_values,
1344
+ hidden_states=transformer_outputs.hidden_states,
1345
+ attentions=transformer_outputs.attentions,
1346
+ )
1347
+
1348
+ @staticmethod
1349
+ def _reorder_cache(
1350
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1351
+ ) -> Tuple[Tuple[torch.Tensor]]:
1352
+ """
1353
+ This function is used to re-order the :obj:`past_key_values` cache if
1354
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1355
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1356
+ """
1357
+ return tuple(
1358
+ tuple(
1359
+ past_state.index_select(0, beam_idx.to(past_state.device))
1360
+ for past_state in layer_past
1361
+ )
1362
+ for layer_past in past
1363
+ )
1364
+
1365
+
1366
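What `_reorder_cache` does, sketched on a dummy single-layer cache; the shapes and the `beam_idx` values below are illustrative only:

    import torch

    # One layer's past: (key, value), each (num_beams, heads, seq_len, head_dim).
    key, value = torch.randn(3, 2, 5, 4), torch.randn(3, 2, 5, 4)
    past = ((key, value),)

    # Beam search kept beam 2 once and beam 0 twice.
    beam_idx = torch.tensor([2, 0, 0])

    reordered = tuple(
        tuple(p.index_select(0, beam_idx.to(p.device)) for p in layer_past)
        for layer_past in past
    )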
+ @add_start_docstrings(
1367
+ """
1368
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
1369
+
1370
+ :class:`~transformers.GPT2ForSequenceClassification` uses the last token in order to do the classification, as
1371
+ other causal models (e.g. GPT-1) do.
1372
+
1373
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1374
+ :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
1375
+ row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
1376
+ guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
1377
+ the last value in each row of the batch).
1378
+ """,
1379
+ GPT2_START_DOCSTRING,
1380
+ )
1381
+ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1382
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
1383
+
1384
+ def __init__(self, config):
1385
+ super().__init__(config)
1386
+ self.num_labels = config.num_labels
1387
+ self.transformer = GPT2Model(config)
1388
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1389
+
1390
+ self.init_weights()
1391
+
1392
+ # Model parallel
1393
+ self.model_parallel = False
1394
+ self.device_map = None
1395
+
1396
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1397
+ @add_code_sample_docstrings(
1398
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1399
+ checkpoint="microsoft/dialogrpt",
1400
+ output_type=SequenceClassifierOutputWithPast,
1401
+ config_class=_CONFIG_FOR_DOC,
1402
+ )
1403
+ def forward(
1404
+ self,
1405
+ input_ids=None,
1406
+ past_key_values=None,
1407
+ attention_mask=None,
1408
+ token_type_ids=None,
1409
+ position_ids=None,
1410
+ head_mask=None,
1411
+ inputs_embeds=None,
1412
+ labels=None,
1413
+ use_cache=None,
1414
+ output_attentions=None,
1415
+ output_hidden_states=None,
1416
+ return_dict=None,
1417
+ ):
1418
+ r"""
1419
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1420
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1421
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1422
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1423
+ """
1424
+ return_dict = (
1425
+ return_dict if return_dict is not None else self.config.use_return_dict
1426
+ )
1427
+
1428
+ transformer_outputs = self.transformer(
1429
+ input_ids,
1430
+ past_key_values=past_key_values,
1431
+ attention_mask=attention_mask,
1432
+ token_type_ids=token_type_ids,
1433
+ position_ids=position_ids,
1434
+ head_mask=head_mask,
1435
+ inputs_embeds=inputs_embeds,
1436
+ use_cache=use_cache,
1437
+ output_attentions=output_attentions,
1438
+ output_hidden_states=output_hidden_states,
1439
+ return_dict=return_dict,
1440
+ )
1441
+ hidden_states = transformer_outputs[0]
1442
+ logits = self.score(hidden_states)
1443
+
1444
+ if input_ids is not None:
1445
+ batch_size, sequence_length = input_ids.shape[:2]
1446
+ else:
1447
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1448
+
1449
+ assert (
1450
+ self.config.pad_token_id is not None or batch_size == 1
1451
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1452
+ if self.config.pad_token_id is None:
1453
+ sequence_lengths = -1
1454
+ else:
1455
+ if input_ids is not None:
1456
+ sequence_lengths = (
1457
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1458
+ )
1459
+ else:
1460
+ sequence_lengths = -1
1461
+ logger.warning(
1462
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1463
+ f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1464
+ )
1465
+
1466
+ pooled_logits = logits[range(batch_size), sequence_lengths]
1467
+
1468
+ loss = None
1469
+ if labels is not None:
1470
+ if self.num_labels == 1:
1471
+ # We are doing regression
1472
+ loss_fct = MSELoss()
1473
+ loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
1474
+ else:
1475
+ loss_fct = CrossEntropyLoss()
1476
+ loss = loss_fct(
1477
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1478
+ )
1479
+
1480
+ if not return_dict:
1481
+ output = (pooled_logits,) + transformer_outputs[1:]
1482
+ return ((loss,) + output) if loss is not None else output
1483
+
1484
+ return SequenceClassifierOutputWithPast(
1485
+ loss=loss,
1486
+ logits=pooled_logits,
1487
+ past_key_values=transformer_outputs.past_key_values,
1488
+ hidden_states=transformer_outputs.hidden_states,
1489
+ attentions=transformer_outputs.attentions,
1490
+ )
1491
+
1492
+
1493
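The pooling step above selects, for every row, the logits at the last non-padding position. A toy sketch of that selection, assuming a configured `pad_token_id` (all values illustrative):

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 6, 7, 0],     # 3 real tokens, 1 pad
                              [8, 9, 0, 0]])    # 2 real tokens, 2 pads
    logits = torch.randn(2, 4, 3)               # (batch, seq_len, num_labels)

    sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 1])
    pooled_logits = logits[range(2), sequence_lengths]                 # (2, 3)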
+ @add_start_docstrings(
1494
+ """
1495
+ GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1496
+ Named-Entity-Recognition (NER) tasks.
1497
+ """,
1498
+ GPT2_START_DOCSTRING,
1499
+ )
1500
+ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1501
+ def __init__(self, config):
1502
+ super().__init__(config)
1503
+ self.num_labels = config.num_labels
1504
+
1505
+ self.transformer = GPT2Model(config)
1506
+ if (
1507
+ hasattr(config, "classifier_dropout")
1508
+ and config.classifier_dropout is not None
1509
+ ):
1510
+ classifier_dropout = config.classifier_dropout
1511
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1512
+ classifier_dropout = config.hidden_dropout
1513
+ else:
1514
+ classifier_dropout = 0.1
1515
+ self.dropout = nn.Dropout(classifier_dropout)
1516
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1517
+
1518
+ self.init_weights()
1519
+
1520
+ # Model parallel
1521
+ self.model_parallel = False
1522
+ self.device_map = None
1523
+
1524
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1525
+ @add_code_sample_docstrings(
1526
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1527
+ checkpoint="microsoft/DialogRPT-updown",
1528
+ output_type=TokenClassifierOutput,
1529
+ config_class=_CONFIG_FOR_DOC,
1530
+ )
1531
+ def forward(
1532
+ self,
1533
+ input_ids=None,
1534
+ past_key_values=None,
1535
+ attention_mask=None,
1536
+ token_type_ids=None,
1537
+ position_ids=None,
1538
+ head_mask=None,
1539
+ inputs_embeds=None,
1540
+ labels=None,
1541
+ use_cache=None,
1542
+ output_attentions=None,
1543
+ output_hidden_states=None,
1544
+ return_dict=None,
1545
+ ):
1546
+ r"""
1547
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1548
+ Labels for computing the token classification loss. Indices should be in :obj:`[0, ...,
1549
+ config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed per token,
1550
+ with padded positions masked out via :obj:`attention_mask`.
1551
+ """
1552
+ return_dict = (
1553
+ return_dict if return_dict is not None else self.config.use_return_dict
1554
+ )
1555
+
1556
+ transformer_outputs = self.transformer(
1557
+ input_ids,
1558
+ past_key_values=past_key_values,
1559
+ attention_mask=attention_mask,
1560
+ token_type_ids=token_type_ids,
1561
+ position_ids=position_ids,
1562
+ head_mask=head_mask,
1563
+ inputs_embeds=inputs_embeds,
1564
+ use_cache=use_cache,
1565
+ output_attentions=output_attentions,
1566
+ output_hidden_states=output_hidden_states,
1567
+ return_dict=return_dict,
1568
+ )
1569
+
1570
+ hidden_states = transformer_outputs[0]
1571
+ hidden_states = self.dropout(hidden_states)
1572
+ logits = self.classifier(hidden_states)
1573
+
1574
+ loss = None
1575
+ if labels is not None:
1576
+ loss_fct = CrossEntropyLoss()
1577
+ # Only keep active parts of the loss
1578
+ if attention_mask is not None:
1579
+ active_loss = attention_mask.view(-1) == 1
1580
+ active_logits = logits.view(-1, self.num_labels)
1581
+ active_labels = torch.where(
1582
+ active_loss,
1583
+ labels.view(-1),
1584
+ torch.tensor(loss_fct.ignore_index).type_as(labels),
1585
+ )
1586
+ loss = loss_fct(active_logits, active_labels)
1587
+ else:
1588
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1589
+
1590
+ if not return_dict:
1591
+ output = (logits,) + transformer_outputs[2:]
1592
+ return ((loss,) + output) if loss is not None else output
1593
+
1594
+ return TokenClassifierOutput(
1595
+ loss=loss,
1596
+ logits=logits,
1597
+ hidden_states=transformer_outputs.hidden_states,
1598
+ attentions=transformer_outputs.attentions,
1599
+ )
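The `attention_mask` branch in the loss above swaps the labels of padded positions for the loss function's `ignore_index`, so padding never contributes to the token-classification loss. A compact sketch with toy tensors:

    import torch
    from torch.nn import CrossEntropyLoss

    num_labels = 3
    logits = torch.randn(1, 4, num_labels)
    labels = torch.tensor([[0, 2, 1, 1]])
    attention_mask = torch.tensor([[1, 1, 1, 0]])   # last position is padding

    loss_fct = CrossEntropyLoss()
    active_loss = attention_mask.view(-1) == 1
    active_labels = torch.where(
        active_loss, labels.view(-1),
        torch.tensor(loss_fct.ignore_index).type_as(labels),
    )
    loss = loss_fct(logits.view(-1, num_labels), active_labels)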
backend/preprocess.py ADDED
@@ -0,0 +1,736 @@
1
+ import html
2
+ import logging
3
+ import re
4
+ from typing import List
5
+ from farasa.segmenter import FarasaSegmenter
6
+ import emoji
7
+
8
+ import pyarabic.araby as araby
9
+
10
+ ACCEPTED_MODELS = [
11
+ "bert-base-arabertv01",
12
+ "bert-base-arabert",
13
+ "bert-base-arabertv02",
14
+ "bert-base-arabertv2",
15
+ "bert-large-arabertv02",
16
+ "bert-large-arabertv2",
17
+ "araelectra-base",
18
+ "araelectra-base-discriminator",
19
+ "araelectra-base-generator",
20
+ "araelectra-base-artydiqa",
21
+ "aragpt2-base",
22
+ "aragpt2-medium",
23
+ "aragpt2-large",
24
+ "aragpt2-mega",
25
+ ]
26
+
27
+ SEGMENTED_MODELS = [
28
+ "bert-base-arabert",
29
+ "bert-base-arabertv2",
30
+ "bert-large-arabertv2",
31
+ ]
32
+
33
+ SECOND_GEN_MODELS = [
34
+ "bert-base-arabertv02",
35
+ "bert-base-arabertv2",
36
+ "bert-large-arabertv02",
37
+ "bert-large-arabertv2",
38
+ "araelectra-base",
39
+ "araelectra-base-discriminator",
40
+ "araelectra-base-generator",
41
+ "araelectra-base-artydiqa",
42
+ "aragpt2-base",
43
+ "aragpt2-medium",
44
+ "aragpt2-large",
45
+ "aragpt2-mega",
46
+ ]
47
+
48
+ farasa_segmenter = FarasaSegmenter(interactive=True)
49
+
50
+
51
+ class ArabertPreprocessor:
52
+ """
53
+ A Preprocessor class that cleans and preprocesses text for all models in the AraBERT repo.
54
+ It can also un-process the text output of generated text.
55
+
56
+ Args:
57
+
58
+ model_name (:obj:`str`): model name from the HuggingFace Models page without
59
+ the aubmindlab tag. Will default to a base Arabic preprocessor if model name was not found.
60
+ Current accepted models are:
61
+
62
+ - "bert-base-arabertv01": No farasa segmentation.
63
+ - "bert-base-arabert": with farasa segmentation.
64
+ - "bert-base-arabertv02": No farasas egmentation.
65
+ - "bert-base-arabertv2": with farasa segmentation.
66
+ - "bert-large-arabertv02": No farasas egmentation.
67
+ - "bert-large-arabertv2": with farasa segmentation.
68
+ - "araelectra-base": No farasa segmentation.
69
+ - "araelectra-base-discriminator": No farasa segmentation.
70
+ - "araelectra-base-generator": No farasa segmentation.
71
+ - "aragpt2-base": No farasa segmentation.
72
+ - "aragpt2-medium": No farasa segmentation.
73
+ - "aragpt2-large": No farasa segmentation.
74
+ - "aragpt2-mega": No farasa segmentation.
75
+
76
+
77
+ keep_emojis(:obj:`bool`, `optional`, defaults to :obj:`False`): don't remove emojis while preprocessing.
78
+
79
+ remove_html_markup(:obj: `bool`, `optional`, defaults to :obj:`True`): Whether to remove html artifacts,
80
+ should be set to False when preprocessing TyDi QA.
81
+
82
+ replace_urls_emails_mentions(:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to replace email urls
83
+ and mentions by special tokens.
84
+
85
+ strip_tashkeel(:obj:`bool`, `optional`, defaults to :obj:`True`): remove diacritics (FATHATAN, DAMMATAN, KASRATAN, FATHA, DAMMA,
86
+ KASRA, SUKUN, SHADDA).
87
+
88
+ strip_tatweel(:obj:`bool`, `optional`, defaults to :obj:`True`): remove tatweel '\\u0640'.
89
+
90
+ insert_white_spaces(:obj:`bool`, `optional`, defaults to :obj:`True`): insert whitespace before and after all non Arabic digits
91
+ or English digits or Arabic and English Alphabet or the 2 brackets, then inserts whitespace
92
+ between words and numbers or numbers and words.
93
+
94
+ remove_non_digit_repetition(:obj:`bool`, `optional`, defaults to :obj:`True`): replace any repetition of more than 2 of the same non-digit character with
95
+ 2 of this character.
96
+
97
+ replace_slash_with_dash(:obj:`bool`, `optional`, defaults to :obj:`None`): Will be automatically set to True in AraBERTv02,
98
+ AraELECTRA and AraGPT2.
99
+ Set to False to force disable, and True to force enable. Replaces the "/" with "-",
100
+ since "/" is missing from AraBERTv2, AraELECTRA and ARAGPT2 vocabulary.
101
+
102
+ map_hindi_numbers_to_arabic(:obj:`bool`, `optional`, defaults to :obj:`None`): Will be automatically set to True in
103
+ AraBERTv02, AraELECTRA and AraGPT2. Set to False to force disable, and True to force enable.
104
+ Replaces hindi numbers with the corresponding Arabic one. ex: "١٩٩٥" --> "1995".
105
+ This behavior is present by default in AraBERTv1 and v2 (with pre-segmentation),
106
+ and fixes an issue caused by a bug when inserting white spaces.
107
+
108
+ apply_farasa_segmentation(:obj:`bool`, `optional`, defaults to :obj:`None`): Will be automatically set to True in
109
+ AraBERTv1 and AraBERTv2. Set to False to force disable, and True to force enable.
110
+
111
+
112
+
113
+ Returns:
114
+
115
+ ArabertPreprocessor: A preprocessor instance
116
+
117
+ Example:
118
+
119
+ from preprocess import ArabertPreprocessor
120
+
121
+ arabert_prep = ArabertPreprocessor("aubmindlab/bert-base-arabertv2")
122
+
123
+ arabert_prep.preprocess("SOME ARABIC TEXT")
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ model_name: str,
129
+ keep_emojis: bool = False,
130
+ remove_html_markup: bool = True,
131
+ replace_urls_emails_mentions: bool = True,
132
+ strip_tashkeel: bool = True,
133
+ strip_tatweel: bool = True,
134
+ insert_white_spaces: bool = True,
135
+ remove_non_digit_repetition: bool = True,
136
+ replace_slash_with_dash: bool = None,
137
+ map_hindi_numbers_to_arabic: bool = None,
138
+ apply_farasa_segmentation: bool = None,
139
+ ):
140
+
141
+ model_name = model_name.replace("aubmindlab/", "").replace("wissamantoun/", "")
142
+
143
+ if model_name not in ACCEPTED_MODELS:
144
+ logging.warning(
145
+ """Model provided is not in the accepted model list. Preprocessor will default to a base Arabic preprocessor"""
146
+ )
147
+ self.model_name = "bert-base-arabertv02"
148
+ else:
149
+ self.model_name = model_name
150
+
151
+ if apply_farasa_segmentation is None:
152
+ if self.model_name in SEGMENTED_MODELS:
153
+ self.apply_farasa_segmentation = True
154
+ else:
155
+ self.apply_farasa_segmentation = False
156
+ else:
157
+ if (
158
+ apply_farasa_segmentation == False
159
+ and self.model_name in SEGMENTED_MODELS
160
+ ):
161
+ logging.warning(
162
+ "The selected model_name requires Farasa pre-segmentation, but apply_farasa_segmentation was set to False!"
163
+ )
164
+
165
+ self.apply_farasa_segmentation = apply_farasa_segmentation
166
+
167
+ self.keep_emojis = keep_emojis
168
+ self.remove_html_markup = remove_html_markup
169
+ self.replace_urls_emails_mentions = replace_urls_emails_mentions
170
+ self.strip_tashkeel = strip_tashkeel
171
+ self.strip_tatweel = strip_tatweel
172
+ self.insert_white_spaces = insert_white_spaces
173
+ self.remove_non_digit_repetition = remove_non_digit_repetition
174
+
175
+ if replace_slash_with_dash is None:
176
+ if self.model_name in SECOND_GEN_MODELS:
177
+ self.replace_slash_with_dash = True
178
+ else:
179
+ self.replace_slash_with_dash = False
180
+ else:
181
+ self.replace_slash_with_dash = replace_slash_with_dash
182
+
183
+ if map_hindi_numbers_to_arabic is None:
184
+ if self.model_name in SECOND_GEN_MODELS:
185
+ self.map_hindi_numbers_to_arabic = True
186
+ else:
187
+ self.map_hindi_numbers_to_arabic = False
188
+ else:
189
+ self.map_hindi_numbers_to_arabic = map_hindi_numbers_to_arabic
190
+
191
+ def preprocess(self, text: str) -> str:
192
+ """
193
+ Preprocess takes an input text line and applies the same preprocessing used in AraBERT
194
+ pretraining, or according to the given settings.
195
+
196
+ Args:
197
+
198
+ text (:obj:`str`): input text string
199
+
200
+ Returns:
201
+
202
+ string: A preprocessed string depending on which model was selected
203
+ """
204
+ if (
205
+ self.model_name == "bert-base-arabert"
206
+ or self.model_name == "bert-base-arabertv01"
207
+ ):
208
+ return self._preprocess_v1(
209
+ text,
210
+ do_farasa_tokenization=self.apply_farasa_segmentation,
211
+ )
212
+
213
+ if self.model_name in SECOND_GEN_MODELS:
214
+ return self._preprocess_v2(text)
215
+
216
+ return self._preprocess_v3(text)
217
+
218
+ def unpreprocess(self, text: str, desegment: bool = True) -> str:
219
+ """Re-formats the text to a classic format where punctuations, brackets, parenthesis are not seperated by whitespaces.
220
+ The objective is to make the generated text of any model appear natural and not preprocessed.
221
+
222
+ Args:
223
+ text (:obj:`str`): input text to be un-preprocessed
224
+ desegment (:obj:`bool`, optional): whether to remove Farasa pre-segmentation first. Defaults to :obj:`True`.
225
+
226
+ Returns:
227
+ str: The unpreprocessed (and possibly Farasa-desegmented) text.
228
+ """
229
+
230
+ if self.apply_farasa_segmentation and desegment:
231
+ text = self.desegment(text)
232
+
233
+ # removes the spaces around quotation marks ex: i " ate " an apple --> i "ate" an apple
234
+ # https://stackoverflow.com/a/53436792/5381220
235
+ text = re.sub(white_spaced_double_quotation_regex, '"' + r"\1" + '"', text)
236
+ text = re.sub(white_spaced_single_quotation_regex, "'" + r"\1" + "'", text)
237
+ text = re.sub(white_spaced_back_quotation_regex, "\`" + r"\1" + "\`", text)
238
+ text = re.sub(white_spaced_back_quotation_regex, "\—" + r"\1" + "\—", text)
239
+
240
+ # during generation, sometimes the models don't put a space after the dot, this handles it
241
+ text = text.replace(".", " . ")
242
+ text = " ".join(text.split())
243
+
244
+ # handle decimals
245
+ text = re.sub(r"(\d+) \. (\d+)", r"\1.\2", text)
246
+ text = re.sub(r"(\d+) \, (\d+)", r"\1,\2", text)
247
+
248
+ text = re.sub(left_and_right_spaced_chars, r"\1", text)
249
+ text = re.sub(left_spaced_chars, r"\1", text)
250
+ text = re.sub(right_spaced_chars, r"\1", text)
251
+
252
+ return text
253
+
254
+ def desegment(self, text: str) -> str:
255
+ """
256
+ Use this function if sentence tokenization was done using
257
+ `from arabert.preprocess_arabert import preprocess` with Farasa enabled
258
+ AraBERT segmentation using Farasa adds a space after the '+' for prefixes,
259
+ and before the '+' for suffixes.
260
+
261
+ Example:
262
+ >>> desegment('ال+ دراس +ات')
263
+ الدراسات
264
+ """
265
+ text = text.replace("+ ", "+")
266
+ text = text.replace(" +", "+")
267
+ text = " ".join([self._desegmentword(word) for word in text.split(" ")])
268
+ return text
269
+
270
+ def _desegmentword(self, orig_word: str) -> str:
271
+ """
272
+ Word segmentor that takes a Farasa Segmented Word and removes the '+' signs
273
+
274
+ Example:
275
+ >>> _desegmentword("ال+يومي+ة")
276
+ اليومية
277
+ """
278
+ word = orig_word.replace("ل+ال+", "لل")
279
+ if "ال+ال" not in orig_word:
280
+ word = word.replace("ل+ال", "لل")
281
+ word = word.replace("+", "")
282
+ word = word.replace("للل", "لل")
283
+ return word
284
+
285
+ def _preprocess_v3(self, text: str) -> str:
286
+ text = str(text)
287
+ text = html.unescape(text)
288
+ if self.strip_tashkeel:
289
+ text = araby.strip_tashkeel(text)
290
+ if self.strip_tatweel:
291
+ text = araby.strip_tatweel(text)
292
+
293
+ if self.replace_urls_emails_mentions:
294
+ # replace all possible URLs
295
+ for reg in url_regexes:
296
+ text = re.sub(reg, " [رابط] ", text)
297
+ # Replace emails with [بريد]
298
+ for reg in email_regexes:
299
+ text = re.sub(reg, " [بريد] ", text)
300
+ # replace mentions with [مستخدم]
301
+ text = re.sub(user_mention_regex, " [مستخدم] ", text)
302
+
303
+ if self.remove_html_markup:
304
+ # remove html line breaks
305
+ text = re.sub("<br />", " ", text)
306
+ # remove html markup
307
+ text = re.sub("</?[^>]+>", " ", text)
308
+
309
+ if self.map_hindi_numbers_to_arabic:
310
+ text = text.translate(hindi_to_arabic_map)
311
+
312
+ # remove repeated characters >2
313
+ if self.remove_non_digit_repetition:
314
+ text = self._remove_non_digit_repetition(text)
315
+
316
+ # insert whitespace before and after all non Arabic digits or English Digits and Alphabet and the 2 brackets
317
+ if self.insert_white_spaces:
318
+ text = re.sub(
319
+ "([^0-9\u0621-\u063A\u0641-\u064A\u0660-\u0669a-zA-Z ])",
320
+ r" \1 ",
321
+ text,
322
+ )
323
+
324
+ # re-fix brackets
325
+ text = text.replace("[ رابط ]", "[رابط]")
326
+ text = text.replace("[ بريد ]", "[بريد]")
327
+ text = text.replace("[ مستخدم ]", "[مستخدم]")
328
+
329
+ # insert whitespace between words and numbers or numbers and words
330
+ text = re.sub(
331
+ "(\d+)([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)",
332
+ r" \1 \2 ",
333
+ text,
334
+ )
335
+ text = re.sub(
336
+ "([\u0621-\u063A\u0641-\u064A\u066A-\u066C\u0654-\u0655]+)(\d+)",
337
+ r" \1 \2 ",
338
+ text,
339
+ )
340
+
341
+ # remove unwanted characters
342
+ if self.keep_emojis:
343
+ emoji_regex = "".join(list(emoji.UNICODE_EMOJI["en"].keys()))
344
+ rejected_chars_regex2 = "[^%s%s]" % (chars_regexv2, emoji_regex)
345
+ text = re.sub(rejected_chars_regex2, " ", text)
346
+ else:
347
+ text = re.sub(rejected_chars_regexv2, " ", text)
348
+
349
+ # remove extra spaces
350
+ text = " ".join(text.replace("\uFE0F", "").split())
351
+
352
+ if self.apply_farasa_segmentation:
353
+ if self.keep_emojis:
354
+ new_text = []
355
+ for word in text.split():
356
+ if word in list(emoji.UNICODE_EMOJI["en"].keys()):
357
+ new_text.append(word)
358
+ else:
359
+ new_text.append(farasa_segmenter.segment(word))
360
+ text = " ".join(new_text)
361
+ else:
362
+ text = farasa_segmenter.segment(text)
363
+ return self._farasa_segment(text)
364
+
365
+ # All the other models don't require Farasa segmentation
366
+ return text
367
+
368
+ def _preprocess_v2(self, text: str) -> str:
369
+ text = str(text)
370
+ text = html.unescape(text)
371
+ if self.strip_tashkeel:
372
+ text = araby.strip_tashkeel(text)
373
+ if self.strip_tatweel:
374
+ text = araby.strip_tatweel(text)
375
+
376
+ if self.replace_urls_emails_mentions:
377
+ # replace all possible URLs
378
+ for reg in url_regexes:
379
+ text = re.sub(reg, " [رابط] ", text)
380
+ # Replace emails with [بريد]
381
+ for reg in email_regexes:
382
+ text = re.sub(reg, " [بريد] ", text)
383
+ # replace mentions with [مستخدم]
384
+ text = re.sub(user_mention_regex, " [مستخدم] ", text)
385
+
386
+ if self.remove_html_markup:
387
+ # remove html line breaks
388
+ text = re.sub("<br />", " ", text)
389
+ # remove html markup
390
+ text = re.sub("</?[^>]+>", " ", text)
391
+
392
+ if self.map_hindi_numbers_to_arabic:
393
+ text = text.translate(hindi_to_arabic_map)
394
+
395
+ # remove repeated characters >2
396
+ if self.remove_non_digit_repetition:
397
+ text = self._remove_non_digit_repetition(text)
398
+
399
+ # insert whitespace before and after all non Arabic digits or English Digits and Alphabet and the 2 brackets
400
+ if self.insert_white_spaces:
401
+ text = re.sub(
402
+ "([^0-9\u0621-\u063A\u0641-\u064A\u0660-\u0669a-zA-Z\[\]])",
403
+ r" \1 ",
404
+ text,
405
+ )
406
+
407
+ # insert whitespace between words and numbers or numbers and words
408
+ text = re.sub(
409
+ "(\d+)([\u0621-\u063A\u0641-\u064A\u0660-\u066C]+)", r" \1 \2 ", text
410
+ )
411
+ text = re.sub(
412
+ "([\u0621-\u063A\u0641-\u064A\u0660-\u066C]+)(\d+)", r" \1 \2 ", text
413
+ )
414
+
415
+ if self.replace_slash_with_dash:
416
+ text = text.replace("/", "-")
417
+
418
+ # remove unwanted characters
419
+ if self.keep_emojis:
420
+ emoji_regex = "".join(list(emoji.UNICODE_EMOJI["en"].keys()))
421
+ rejected_chars_regex2 = "[^%s%s]" % (chars_regex, emoji_regex)
422
+ text = re.sub(rejected_chars_regex2, " ", text)
423
+ else:
424
+ text = re.sub(rejected_chars_regex, " ", text)
425
+
426
+ # remove extra spaces
427
+ text = " ".join(text.replace("\uFE0F", "").split())
428
+
429
+ if (
430
+ self.model_name == "bert-base-arabertv2"
431
+ or self.model_name == "bert-large-arabertv2"
432
+ ):
433
+ if self.keep_emojis:
434
+ new_text = []
435
+ for word in text.split():
436
+ if word in list(emoji.UNICODE_EMOJI["en"].keys()):
437
+ new_text.append(word)
438
+ else:
439
+ new_text.append(farasa_segmenter.segment(word))
440
+ text = " ".join(new_text)
441
+ else:
442
+ text = farasa_segmenter.segment(text)
443
+ return self._farasa_segment(text)
444
+
445
+ # All the other models don't require Farasa segmentation
446
+ return text
447
+
448
+ def _preprocess_v1(self, text: str, do_farasa_tokenization: bool) -> str:
449
+ """
450
+ AraBERTv1 preprocessing Function
451
+ """
452
+ text = str(text)
453
+ if self.strip_tashkeel:
454
+ text = araby.strip_tashkeel(text)
455
+
456
+ text = re.sub(r"\d+\/[ء-ي]+\/\d+\]", "", text)
457
+ text = re.sub("ـ", "", text)
458
+ text = re.sub("[«»]", ' " ', text)
459
+
460
+ if self.replace_urls_emails_mentions:
461
+ # replace the [رابط] token with space if you want to clean links
462
+ text = re.sub(regex_url_step1, "[رابط]", text)
463
+ text = re.sub(regex_url_step2, "[رابط]", text)
464
+ text = re.sub(regex_url, "[رابط]", text)
465
+ text = re.sub(regex_email, "[بريد]", text)
466
+ text = re.sub(regex_mention, "[مستخدم]", text)
467
+ text = re.sub("…", r"\.", text).strip()
468
+ text = self._remove_redundant_punct(text)
469
+
470
+ if self.replace_urls_emails_mentions:
471
+ text = re.sub(r"\[ رابط \]|\[ رابط\]|\[رابط \]", " [رابط] ", text)
472
+ text = re.sub(r"\[ بريد \]|\[ بريد\]|\[بريد \]", " [بريد] ", text)
473
+ text = re.sub(r"\[ مستخدم \]|\[ مستخدم\]|\[مستخدم \]", " [مستخدم] ", text)
474
+
475
+ if self.remove_non_digit_repetition:
476
+ text = self._remove_non_digit_repetition(text)
477
+
478
+ if self.insert_white_spaces:
479
+ text = re.sub(
480
+ "([^0-9\u0621-\u063A\u0641-\u0669\u0671-\u0673a-zA-Z\[\]])",
481
+ r" \1 ",
482
+ text,
483
+ )
484
+ if do_farasa_tokenization:
485
+ text = self._tokenize_arabic_words_farasa(text)
486
+
487
+ text = " ".join(text.split())
488
+
489
+ return text
490
+
491
+ def _farasa_segment(self, text: str) -> str:
492
+ line_farasa = text.split()
493
+ segmented_line = []
494
+ for index, word in enumerate(line_farasa):
495
+ if word in ["[", "]"]:
496
+ continue
497
+ if word in ["رابط", "بريد", "مستخدم"] and line_farasa[index - 1] in [
498
+ "[",
499
+ "]",
500
+ ]:
501
+ segmented_line.append("[" + word + "]")
502
+ continue
503
+ if "+" not in word:
504
+ segmented_line.append(word)
505
+ continue
506
+ segmented_word = self._split_farasa_output(word)
507
+ segmented_line.extend(segmented_word)
508
+
509
+ return " ".join(segmented_line)
510
+
511
+ def _split_farasa_output(self, word: str) -> str:
512
+ segmented_word = []
513
+ temp_token = ""
514
+ for i, c in enumerate(word):
515
+ if c == "+":
516
+ # if the token is KAF, it could be a suffix or prefix
517
+ if temp_token == "ك":
518
+ # if we are at the second token, then KAF is surely a prefix
519
+ if i == 1:
520
+ segmented_word.append(temp_token + "+")
521
+ temp_token = ""
522
+ # If the KAF token is between 2 tokens
523
+ elif word[i - 2] == "+":
524
+ # if the previous token is prefix, then this KAF must be a prefix
525
+ if segmented_word[-1][-1] == "+":
526
+ segmented_word.append(temp_token + "+")
527
+ temp_token = ""
528
+ # else it is a suffix, this KAF could not be a second suffix
529
+ else:
530
+ segmented_word.append("+" + temp_token)
531
+ temp_token = ""
532
+ # if Kaf is at the end, this is handled with the statement after the loop
533
+ elif temp_token in prefix_list:
534
+ segmented_word.append(temp_token + "+")
535
+ temp_token = ""
536
+ elif temp_token in suffix_list:
537
+ segmented_word.append("+" + temp_token)
538
+ temp_token = ""
539
+ else:
540
+ segmented_word.append(temp_token)
541
+ temp_token = ""
542
+ continue
543
+ temp_token += c
544
+ if temp_token != "":
545
+ if temp_token in suffix_list:
546
+ segmented_word.append("+" + temp_token)
547
+ else:
548
+ segmented_word.append(temp_token)
549
+ return segmented_word
550
+
551
+ def _tokenize_arabic_words_farasa(self, line_input: str) -> str:
552
+
553
+ if self.keep_emojis:
554
+ # segment word by word so that emojis are left intact
555
+ line_farasa = []
556
+ for word in line_input.split():
557
+ if word in list(emoji.UNICODE_EMOJI["en"].keys()):
558
+ line_farasa.append(word)
559
+ else:
560
+ line_farasa.append(farasa_segmenter.segment(word))
561
+ else:
562
+ line_farasa = farasa_segmenter.segment(line_input).split()
563
+
564
+ segmented_line = []
565
+ for index, word in enumerate(line_farasa):
566
+ if word in ["[", "]"]:
567
+ continue
568
+ if word in ["رابط", "بريد", "مستخدم"] and line_farasa[index - 1] in [
569
+ "[",
570
+ "]",
571
+ ]:
572
+ segmented_line.append("[" + word + "]")
573
+ continue
574
+ segmented_word = []
575
+ for token in word.split("+"):
576
+ if token in prefix_list:
577
+ segmented_word.append(token + "+")
578
+ elif token in suffix_list:
579
+ segmented_word.append("+" + token)
580
+ else:
581
+ segmented_word.append(token)
582
+ segmented_line.extend(segmented_word)
583
+ return " ".join(segmented_line)
584
+
585
+ def _remove_non_digit_repetition(self, text: str) -> str:
586
+ """
587
+ :param text: the input text to remove elongation from
588
+ :return: de-elongated text
589
+ """
590
+ # loop over the number of times the regex matched the text
591
+ # OLD
592
+ # for index_ in range(len(re.findall(regex_tatweel, text))):
593
+ # elongation = re.search(regex_tatweel, text)
594
+ # if elongation:
595
+ # elongation_pattern = elongation.group()
596
+ # elongation_replacement = elongation_pattern[0]
597
+ # elongation_pattern = re.escape(elongation_pattern)
598
+ # text = re.sub(
599
+ # elongation_pattern, elongation_replacement, text, flags=re.MULTILINE
600
+ # )
601
+ # else:
602
+ # break
603
+
604
+ # New
605
+ text = multiple_char_pattern.sub(r"\1\1", text)
606
+ return text
607
+
608
+ def _remove_redundant_punct(self, text: str) -> str:
609
+ text_ = text
610
+ result = re.search(redundant_punct_pattern, text)
611
+ dif = 0
612
+ while result:
613
+ sub = result.group()
614
+ sub = sorted(set(sub), key=sub.index)
615
+ sub = " " + "".join(list(sub)) + " "
616
+ text = "".join(
617
+ (text[: result.span()[0] + dif], sub, text[result.span()[1] + dif :])
618
+ )
619
+ text_ = "".join(
620
+ (text_[: result.span()[0]], text_[result.span()[1] :])
621
+ ).strip()
622
+ dif = abs(len(text) - len(text_))
623
+ result = re.search(redundant_punct_pattern, text_)
624
+ text = re.sub(r"\s+", " ", text)
625
+ return text.strip()
626
+
627
+
628
+ prefix_list = [
629
+ "ال",
630
+ "و",
631
+ "ف",
632
+ "ب",
633
+ "ك",
634
+ "ل",
635
+ "لل",
636
+ "\u0627\u0644",
637
+ "\u0648",
638
+ "\u0641",
639
+ "\u0628",
640
+ "\u0643",
641
+ "\u0644",
642
+ "\u0644\u0644",
643
+ "س",
644
+ ]
645
+ suffix_list = [
646
+ "ه",
647
+ "ها",
648
+ "ك",
649
+ "ي",
650
+ "هما",
651
+ "كما",
652
+ "نا",
653
+ "كم",
654
+ "هم",
655
+ "هن",
656
+ "كن",
657
+ "ا",
658
+ "ان",
659
+ "ين",
660
+ "ون",
661
+ "وا",
662
+ "ات",
663
+ "ت",
664
+ "ن",
665
+ "ة",
666
+ "\u0647",
667
+ "\u0647\u0627",
668
+ "\u0643",
669
+ "\u064a",
670
+ "\u0647\u0645\u0627",
671
+ "\u0643\u0645\u0627",
672
+ "\u0646\u0627",
673
+ "\u0643\u0645",
674
+ "\u0647\u0645",
675
+ "\u0647\u0646",
676
+ "\u0643\u0646",
677
+ "\u0627",
678
+ "\u0627\u0646",
679
+ "\u064a\u0646",
680
+ "\u0648\u0646",
681
+ "\u0648\u0627",
682
+ "\u0627\u062a",
683
+ "\u062a",
684
+ "\u0646",
685
+ "\u0629",
686
+ ]
687
+ other_tokens = ["[رابط]", "[مستخدم]", "[بريد]"]
688
+
689
+ # the never_split list is used with the transformers library
690
+ prefix_symbols = [x + "+" for x in prefix_list]
691
+ suffix_symbols = ["+" + x for x in suffix_list]
692
+ never_split_tokens = list(set(prefix_symbols + suffix_symbols + other_tokens))
693
+
694
+ url_regexes = [
695
+ r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)",
696
+ r"@(https?|ftp)://(-\.)?([^\s/?\.#-]+\.?)+(/[^\s]*)?$@iS",
697
+ r"http[s]?://[a-zA-Z0-9_\-./~\?=%&]+",
698
+ r"www[a-zA-Z0-9_\-?=%&/.~]+",
699
+ r"[a-zA-Z]+\.com",
700
+ r"(?=http)[^\s]+",
701
+ r"(?=www)[^\s]+",
702
+ r"://",
703
+ ]
704
+ user_mention_regex = r"@[\w\d]+"
705
+ email_regexes = [r"[\w-]+@([\w-]+\.)+[\w-]+", r"\S+@\S+"]
706
+ redundant_punct_pattern = (
707
+ r"([!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ【»؛\s+«–…‘]{2,})"
708
+ )
709
+
710
+ regex_tatweel = r"(\D)\1{2,}"
711
+ multiple_char_pattern = re.compile(r"(\D)\1{2,}", re.DOTALL)
712
+
713
+ rejected_chars_regex = r"[^0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘]"
714
+ rejected_chars_regexv2 = r"[^0-9\u0621-\u063A\u0641-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘/]"
715
+
716
+ regex_url_step1 = r"(?=http)[^\s]+"
717
+ regex_url_step2 = r"(?=www)[^\s]+"
718
+ regex_url = r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)"
719
+ regex_mention = r"@[\w\d]+"
720
+ regex_email = r"\S+@\S+"
721
+
722
+ chars_regex = r"0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘"
723
+ chars_regexv2 = r"0-9\u0621-\u063A\u0640-\u066C\u0671-\u0674a-zA-Z\[\]!\"#\$%\'\(\)\*\+,\.:;\-<=·>?@\[\\\]\^_ـ`{\|}~—٪’،؟`୍“؛”ۚ»؛\s+«–…‘/"
724
+
725
+ white_spaced_double_quotation_regex = r'\"\s+([^"]+)\s+\"'
726
+ white_spaced_single_quotation_regex = r"\'\s+([^']+)\s+\'"
727
+ white_spaced_back_quotation_regex = r"\`\s+([^`]+)\s+\`"
728
+ white_spaced_em_dash = r"\—\s+([^—]+)\s+\—"
729
+
730
+ left_spaced_chars = r" ([\]!#\$%\),\.:;\?}٪’،؟”؛…»·])"
731
+ right_spaced_chars = r"([\[\(\{“«‘*\~]) "
732
+ left_and_right_spaced_chars = r" ([\+\-\<\=\>\@\\\^\_\|\–]) "
733
+
734
+ hindi_nums = "٠١٢٣٤٥٦٧٨٩"
735
+ arabic_nums = "0123456789"
736
+ hindi_to_arabic_map = str.maketrans(hindi_nums, arabic_nums)
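A short usage sketch for the module above. Importing it starts a Farasa segmenter, so farasapy and a Java runtime must be available; the sample strings are illustrative:

    from backend.preprocess import ArabertPreprocessor, hindi_to_arabic_map

    prep = ArabertPreprocessor(model_name="bert-base-arabertv02")
    # URLs become [رابط], emails [بريد], and mentions [مستخدم]:
    clean = prep.preprocess("شاهد الرابط https://example.com من @user")

    # The translation table maps Hindi digits to Western Arabic digits:
    "٢٠٢١".translate(hindi_to_arabic_map)   # -> "2021"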
backend/processor.py ADDED
@@ -0,0 +1,183 @@
1
+ import streamlit as st
2
+ import awesome_streamlit as ast
3
+ from .preprocess import (
4
+ ArabertPreprocessor,
5
+ white_spaced_back_quotation_regex,
6
+ white_spaced_double_quotation_regex,
7
+ white_spaced_em_dash,
8
+ white_spaced_single_quotation_regex,
9
+ left_and_right_spaced_chars,
10
+ left_spaced_chars,
11
+ right_spaced_chars,
12
+ )
13
+ import re
14
+
15
+ MODELS_to_SELECT = [
16
+ "None",
17
+ "bert-base-arabertv01",
18
+ "bert-base-arabert",
19
+ "bert-base-arabertv02",
20
+ "bert-base-arabertv2",
21
+ "bert-large-arabertv02",
22
+ "bert-large-arabertv2",
23
+ "araelectra-base",
24
+ "araelectra-base-discriminator",
25
+ "araelectra-base-generator",
26
+ "araelectra-base-artydiqa",
27
+ "aragpt2-base",
28
+ "aragpt2-medium",
29
+ "aragpt2-large",
30
+ "aragpt2-mega",
31
+ ]
32
+
33
+
34
+ def unpreprocess(text: str) -> str:
35
+ """Re-formats the text to a classic format where punctuations, brackets, parenthesis are not seperated by whitespaces.
36
+ The objective is to make the generated text of any model appear natural and not preprocessed.
37
+
38
+ Args:
39
+ text (:obj:`str`): input text to be un-preprocessed
40
+ Note: Farasa de-segmentation is always applied first.
41
+
42
+ Returns:
43
+ str: The unpreprocessed (and possibly Farasa-desegmented) text.
44
+ """
45
+
46
+ text = desegment(text)
47
+
48
+ # removes the spaces around quotation marks ex: i " ate " an apple --> i "ate" an apple
49
+ # https://stackoverflow.com/a/53436792/5381220
50
+ text = re.sub(white_spaced_double_quotation_regex, '"' + r"\1" + '"', text)
51
+ text = re.sub(white_spaced_single_quotation_regex, "'" + r"\1" + "'", text)
52
+ text = re.sub(white_spaced_back_quotation_regex, "\`" + r"\1" + "\`", text)
53
+ text = re.sub(white_spaced_back_quotation_regex, "\—" + r"\1" + "\—", text)
54
+
55
+ # during generation, sometimes the models don't put a space after the dot, this handles it
56
+ text = text.replace(".", " . ")
57
+ text = " ".join(text.split())
58
+
59
+ # handle decimals
60
+ text = re.sub(r"(\d+) \. (\d+)", r"\1.\2", text)
61
+ text = re.sub(r"(\d+) \, (\d+)", r"\1,\2", text)
62
+
63
+ text = re.sub(left_and_right_spaced_chars, r"\1", text)
64
+ text = re.sub(left_spaced_chars, r"\1", text)
65
+ text = re.sub(right_spaced_chars, r"\1", text)
66
+
67
+ return text
68
+
69
+
70
+ def desegment(text: str) -> str:
71
+ """
72
+ Use this function if sentence tokenization was done using
73
+ `from arabert.preprocess_arabert import preprocess` with Farasa enabled
74
+ AraBERT segmentation using Farasa adds a space after the '+' for prefixes,
75
+ and before the '+' for suffixes.
76
+
77
+ Example:
78
+ >>> desegment('ال+ دراس +ات')
79
+ الدراسات
80
+ """
81
+ text = text.replace("+ ", "+")
82
+ text = text.replace(" +", "+")
83
+ text = " ".join([_desegmentword(word) for word in text.split(" ")])
84
+ return text
85
+
86
+
87
+ def _desegmentword(orig_word: str) -> str:
88
+ """
89
+ Word segmentor that takes a Farasa Segmented Word and removes the '+' signs
90
+
91
+ Example:
92
+ >>> _desegmentword("ال+يومي+ة")
93
+ اليومية
94
+ """
95
+ word = orig_word.replace("ل+ال+", "لل")
96
+ if "ال+ال" not in orig_word:
97
+ word = word.replace("ل+ال", "لل")
98
+ word = word.replace("+", "")
99
+ word = word.replace("للل", "لل")
100
+ return word
101
+
102
+
103
+ def write():
104
+
105
+ st.markdown(
106
+ """
107
+ <h1 style="text-align:left;">Arabic Text Pre-Processor</h1>
108
+ """,
109
+ unsafe_allow_html=True,
110
+ )
111
+ st.markdown(
112
+ """
113
+ <style>
114
+ p, div, input, label {
115
+ text-align: right;
116
+ }
117
+ </style>
118
+ """,
119
+ unsafe_allow_html=True,
120
+ )
121
+ input_text = st.text_input(
122
+ "Text to Pre-Process",
123
+ value="ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري",
124
+ )
125
+
126
+ st.sidebar.title("Model Selector")
127
+ model_selector = st.sidebar.selectbox(
128
+ """Select None to enable further filters""", options=MODELS_to_SELECT, index=3
129
+ )
130
+ if model_selector == "None":
131
+ keep_emojis = st.sidebar.checkbox("Keep emojis", False)
132
+ remove_html_markup = st.sidebar.checkbox("Remove html markup", True)
133
+ strip_tashkeel = st.sidebar.checkbox("Strip tashkeel", True)
134
+ replace_urls_emails_mentions = st.sidebar.checkbox(
135
+ "Replace urls and emails", True
136
+ )
137
+ strip_tatweel = st.sidebar.checkbox("Strip tatweel", True)
138
+ insert_white_spaces = st.sidebar.checkbox("Insert white spaces", True)
139
+ remove_non_digit_repetition = st.sidebar.checkbox(
140
+ "Remove non-digit repetition", True
141
+ )
142
+ replace_slash_with_dash = st.sidebar.checkbox("Replace slash with dash", None)
143
+ map_hindi_numbers_to_arabic = st.sidebar.checkbox(
144
+ "Map hindi numbers to arabic", None
145
+ )
146
+ apply_farasa_segmentation = st.sidebar.checkbox(
147
+ "Apply farasa segmentation", None
148
+ )
149
+
150
+ run_preprocessor = st.button("Run Pre-Processor")
151
+
152
+ prep_text = None
153
+ if run_preprocessor:
154
+ if model_selector == "None":
155
+ arabert_preprocessor = ArabertPreprocessor(
156
+ model_selector,
157
+ keep_emojis,
158
+ remove_html_markup,
159
+ replace_urls_emails_mentions,
160
+ strip_tashkeel,
161
+ strip_tatweel,
162
+ insert_white_spaces,
163
+ remove_non_digit_repetition,
164
+ replace_slash_with_dash,
165
+ map_hindi_numbers_to_arabic,
166
+ apply_farasa_segmentation,
167
+ )
168
+ else:
169
+ arabert_preprocessor = ArabertPreprocessor(model_name=model_selector)
170
+ prep_text = arabert_preprocessor._preprocess_v3(input_text)
171
+ st.write(prep_text)
172
+
173
+ st.write("-----")
174
+ input_text_unprep = st.text_input(
175
+ "Text to Undo the Pre-Processing",
176
+ value=prep_text
177
+ if prep_text
178
+ else "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري",
179
+ )
180
+ run_unpreprocessor = st.button("Run Un-Pre-Processor")
181
+
182
+ if run_unpreprocessor:
183
+ st.write(unpreprocess(input_text_unprep))
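A sketch of the round trip this page performs: preprocessing with a v2 model adds Farasa '+' markers and spaces out punctuation, and unpreprocess() reverses both. The expected output shown here is based on the page's own default strings:

    from backend.processor import unpreprocess

    segmented = "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' ضروري"
    unpreprocess(segmented)
    # -> "ولن نبالغ إذا قلنا: إن 'هاتف' ضروري"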
backend/qa.py ADDED
@@ -0,0 +1,45 @@
1
+ import streamlit as st
2
+
3
+ from .qa_utils import annotate_answer
4
+ from .services import get_qa_answers
5
+
6
+
7
+ def write():
8
+ _, col1, _ = st.columns(3)
9
+
10
+ with col1:
11
+ st.title("Ask any question!")
12
+
13
+ st.markdown(
14
+ """
15
+ <style>
16
+ p, div, input, label {
17
+ text-align: right;
18
+ }
19
+ </style>
20
+ """,
21
+ unsafe_allow_html=True,
22
+ )
23
+
24
+ st.sidebar.write("\n")
25
+ n_answers = st.sidebar.slider(
26
+ "Max. number of answers", min_value=1, max_value=10, value=2, step=1
27
+ )
28
+
29
+ question = st.text_input("", value="من هو جو بايدن؟")
30
+ if "؟" not in question:
31
+ question += "؟"
32
+
33
+ run_query = st.button("Find answers")
34
+ if run_query:
35
+ # https://discuss.streamlit.io/t/showing-a-gif-while-st-spinner-runs/5084
36
+ with st.spinner("Searching..."):
37
+ results_dict = get_qa_answers(question)
38
+
39
+ if len(results_dict) > 0:
40
+ st.write("## Answers:")
41
+ for result in results_dict["results"][:n_answers]:
42
+ annotate_answer(result)
43
+ f"[**Source**](<{result['link']}>)"
44
+ else:
45
+ st.write("## 😞 No results found.")
backend/qa_utils.py ADDED
@@ -0,0 +1,163 @@
1
+ import streamlit.components.v1
2
+
3
+ from htbuilder import HtmlElement, div, span, styles
4
+ from htbuilder.units import px, rem, em
5
+
6
+
7
+ def annotation(body, label="", background="#ddd", color="#333", **style):
8
+ """Build an HtmlElement span object with the given body and annotation label.
9
+
10
+ The end result will look something like this:
11
+
12
+ [body | label]
13
+
14
+ Parameters
15
+ ----------
16
+ body : string
17
+ The string to put in the "body" part of the annotation.
18
+ label : string
19
+ The string to put in the "label" part of the annotation.
20
+ background : string
21
+ The color to use for the background "chip" containing this annotation.
22
+ color : string
23
+ The color to use for the body and label text.
24
+ **style : dict
25
+ Any CSS you want to use to customize the containing "chip".
26
+
27
+ Examples
28
+ --------
29
+
30
+ Produce a simple annotation with default colors:
31
+
32
+ >>> annotation("apple", "fruit")
33
+
34
+ Produce an annotation with custom colors:
35
+
36
+ >>> annotation("apple", "fruit", background="#FF0", color="black")
37
+
38
+ Produce an annotation with crazy CSS:
39
+
40
+ >>> annotation("apple", "fruit", background="#FF0", border="1px dashed red")
41
+
42
+ """
43
+
44
+ if "font_family" not in style:
45
+ style["font_family"] = "sans-serif"
46
+
47
+ return span(
48
+ style=styles(
49
+ background=background,
50
+ border_radius=rem(0.33),
51
+ color=color,
52
+ padding=(rem(0.17), rem(0.67)),
53
+ display="inline-flex",
54
+ justify_content="center",
55
+ align_items="center",
56
+ **style,
57
+ )
58
+ )(
59
+ body,
60
+ span(
61
+ style=styles(
62
+ color=color,
63
+ font_size=em(0.67),
64
+ opacity=0.5,
65
+ padding_left=rem(0.5),
66
+ text_transform="uppercase",
67
+ margin_bottom=px(-2),
68
+ )
69
+ )(label),
70
+ )
71
+
72
+
73
+ def annotated_text(*args, **kwargs):
74
+ """Writes test with annotations into your Streamlit app.
75
+
76
+ Parameters
77
+ ----------
78
+ *args : str, tuple or htbuilder.HtmlElement
79
+ Arguments can be:
80
+ - strings, to draw the string as-is on the screen.
81
+ - tuples of the form (main_text, annotation_text, background, color) where
82
+ background and foreground colors are optional and should be a CSS-valid string such as
83
+ "#aabbcc" or "rgb(10, 20, 30)"
84
+ - HtmlElement objects in case you want to customize the annotations further. In particular,
85
+ you can import the `annotation()` function from this module to easily produce annotations
86
+ whose CSS you can customize via keyword arguments.
87
+
88
+ Examples
89
+ --------
90
+
91
+ >>> annotated_text(
92
+ ... "This ",
93
+ ... ("is", "verb", "#8ef"),
94
+ ... " some ",
95
+ ... ("annotated", "adj", "#faa"),
96
+ ... ("text", "noun", "#afa"),
97
+ ... " for those of ",
98
+ ... ("you", "pronoun", "#fea"),
99
+ ... " who ",
100
+ ... ("like", "verb", "#8ef"),
101
+ ... " this sort of ",
102
+ ... ("thing", "noun", "#afa"),
103
+ ... )
104
+
105
+ >>> annotated_text(
106
+ ... "Hello ",
107
+ ... annotation("world!", "noun", color="#8ef", border="1px dashed red"),
108
+ ... )
109
+
110
+ """
111
+ out = div(
112
+ style=styles(
113
+ font_family="sans-serif",
114
+ line_height="1.45",
115
+ font_size=px(16),
116
+ text_align="right",
117
+ )
118
+ )
119
+
120
+ for arg in args:
121
+ if isinstance(arg, str):
122
+ out(arg)
123
+
124
+ elif isinstance(arg, HtmlElement):
125
+ out(arg)
126
+
127
+ elif isinstance(arg, tuple):
128
+ out(annotation(*arg))
129
+
130
+ else:
131
+ raise TypeError(f"Unsupported argument type: {type(arg)}")
132
+
133
+ streamlit.components.v1.html(str(out), **kwargs)
134
+
135
+
136
+ def shorten_text(text, n, reverse=False):
137
+ if text.isspace() or text == "":
138
+ return text
139
+ if reverse:
140
+ text = text[::-1]
141
+ words = iter(text.split())
142
+ lines, current = [], next(words)
143
+ for word in words:
144
+ if len(current) + 1 + len(word) > n:
145
+ break
146
+ else:
147
+ current += " " + word
148
+ lines.append(current)
149
+ if reverse:
150
+ return lines[0][::-1]
151
+ return lines[0]
152
+
153
+
154
+ def annotate_answer(result):
155
+ annotated_text(
156
+ shorten_text(
157
+ result["original"][: result["new_start"]],
158
+ 500,
159
+ reverse=True,
160
+ ),
161
+ (result["new_answer"], "جواب", "#8ef"),
162
+ shorten_text(result["original"][result["new_end"] :], 500) + " ...... إلخ",
163
+ )
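The shape of the `result` dict consumed by annotate_answer (and by the QA page's source link); the keys come from the code above, the values are dummy placeholders:

    result = {
        "original": "جو بايدن هو رئيس الولايات المتحدة",   # full source passage
        "new_answer": "جو بايدن",                          # extracted answer span
        "new_start": 0,                                    # character offsets into "original"
        "new_end": 8,
        "link": "https://example.com/source",
    }
    annotate_answer(result)   # renders the passage with the answer highlighted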
backend/sa.py ADDED
@@ -0,0 +1,36 @@
1
+ import streamlit as st
2
+ from .services import SentimentAnalyzer
3
+ from functools import lru_cache
4
+
5
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
6
+ @lru_cache(maxsize=1)
7
+ def load_sentiment_analyzer():
8
+ predictor = SentimentAnalyzer()
9
+ return predictor
10
+
11
+
12
+ predictor = load_sentiment_analyzer()
13
+
14
+
15
+ def write():
16
+ st.markdown(
17
+ """
18
+ # Arabic Sentiment Analysis
19
+
20
+ """
21
+ )
22
+
23
+ input_text = st.text_input(
24
+ "Enter your text here:",
25
+ )
26
+ if st.button("Predict"):
27
+ with st.spinner("Predicting..."):
28
+ prediction, score, all_score = predictor.predict([input_text])
29
+ st.write(f"Result: {prediction[0]}")
30
+ detailed_score = {
31
+ "Positive": all_score[0][0],
32
+ "Neutral": all_score[0][1],
33
+ "Negative": all_score[0][2],
34
+ }
35
+ st.write("All scores:")
36
+ st.write(detailed_score)
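The unpacking above assumes the analyzer returns one label, one top score, and a per-class score triple for each input, ordered Positive / Neutral / Negative as in the dict built here. An illustrative call (values are made up):

    prediction, score, all_score = predictor.predict(["النص جميل جدا"])
    # prediction -> ["Positive"]           one label per input text
    # all_score  -> [[0.91, 0.07, 0.02]]   Positive / Neutral / Negative scores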
backend/sa_utils.py ADDED
@@ -0,0 +1,510 @@
1
+ import re
2
+ from contextlib import contextmanager
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from fuzzysearch import find_near_matches
8
+ from pyarabic import araby
9
+ from torch import nn
10
+ from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, pipeline
11
+ from transformers.modeling_outputs import SequenceClassifierOutput
12
+
13
+ from .preprocess import ArabertPreprocessor, url_regexes, user_mention_regex
14
+
15
+ multiple_char_pattern = re.compile(r"(.)\1{2,}", re.DOTALL)
16
+
17
+ # ASAD-NEW_AraBERT_PREP-Balanced
18
+ class NewArabicPreprocessorBalanced(ArabertPreprocessor):
19
+ def __init__(
20
+ self,
21
+ model_name: str,
22
+ keep_emojis: bool = False,
23
+ remove_html_markup: bool = True,
24
+ replace_urls_emails_mentions: bool = True,
25
+ strip_tashkeel: bool = True,
26
+ strip_tatweel: bool = True,
27
+ insert_white_spaces: bool = True,
28
+ remove_non_digit_repetition: bool = True,
29
+ replace_slash_with_dash: bool = None,
30
+ map_hindi_numbers_to_arabic: bool = None,
31
+ apply_farasa_segmentation: bool = None,
32
+ ):
33
+ if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
34
+ keep_emojis = True
35
+ remove_non_digit_repetition = True
36
+ super().__init__(
37
+ model_name=model_name,
38
+ keep_emojis=keep_emojis,
39
+ remove_html_markup=remove_html_markup,
40
+ replace_urls_emails_mentions=replace_urls_emails_mentions,
41
+ strip_tashkeel=strip_tashkeel,
42
+ strip_tatweel=strip_tatweel,
43
+ insert_white_spaces=insert_white_spaces,
44
+ remove_non_digit_repetition=remove_non_digit_repetition,
45
+ replace_slash_with_dash=replace_slash_with_dash,
46
+ map_hindi_numbers_to_arabic=map_hindi_numbers_to_arabic,
47
+ apply_farasa_segmentation=apply_farasa_segmentation,
48
+ )
49
+ self.true_model_name = model_name
50
+
51
+ def preprocess(self, text):
52
+ if "UBC-NLP" in self.true_model_name:
53
+ return self.ubc_prep(text)
+ return super().preprocess(text)  # fall through for non-UBC models, like the other preprocessors
54
+
55
+ def ubc_prep(self, text):
56
+ text = re.sub("\s", " ", text)
57
+ text = text.replace("\\n", " ")
58
+ text = text.replace("\\r", " ")
59
+ text = araby.strip_tashkeel(text)
60
+ text = araby.strip_tatweel(text)
61
+ # replace all possible URLs
62
+ for reg in url_regexes:
63
+ text = re.sub(reg, " URL ", text)
64
+ text = re.sub("(URL\s*)+", " URL ", text)
65
+ # replace mentions with USER
66
+ text = re.sub(user_mention_regex, " USER ", text)
67
+ text = re.sub("(USER\s*)+", " USER ", text)
68
+ # replace hashtags with HASHTAG
69
+ # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
70
+ text = text.replace("#", " HASH ")
71
+ text = text.replace("_", " ")
72
+ text = " ".join(text.split())
73
+ # text = re.sub("\B\\[Uu]\w+", "", text)
74
+ text = text.replace("\\U0001f97a", "🥺")
75
+ text = text.replace("\\U0001f928", "🤨")
76
+ text = text.replace("\\U0001f9d8", "😀")
77
+ text = text.replace("\\U0001f975", "😥")
78
+ text = text.replace("\\U0001f92f", "😲")
79
+ text = text.replace("\\U0001f92d", "🤭")
80
+ text = text.replace("\\U0001f9d1", "😐")
81
+ text = text.replace("\\U000e0067", "")
82
+ text = text.replace("\\U000e006e", "")
83
+ text = text.replace("\\U0001f90d", "♥")
84
+ text = text.replace("\\U0001f973", "🎉")
85
+ text = text.replace("\\U0001fa79", "")
86
+ text = text.replace("\\U0001f92b", "🤐")
87
+ text = text.replace("\\U0001f9da", "🦋")
88
+ text = text.replace("\\U0001f90e", "♥")
89
+ text = text.replace("\\U0001f9d0", "🧐")
90
+ text = text.replace("\\U0001f9cf", "")
91
+ text = text.replace("\\U0001f92c", "😠")
92
+ text = text.replace("\\U0001f9f8", "😸")
93
+ text = text.replace("\\U0001f9b6", "💩")
94
+ text = text.replace("\\U0001f932", "🤲")
95
+ text = text.replace("\\U0001f9e1", "🧡")
96
+ text = text.replace("\\U0001f974", "☹")
97
+ text = text.replace("\\U0001f91f", "")
98
+ text = text.replace("\\U0001f9fb", "💩")
99
+ text = text.replace("\\U0001f92a", "🤪")
100
+ text = text.replace("\\U0001f9fc", "")
101
+ text = text.replace("\\U000e0065", "")
102
+ text = text.replace("\\U0001f92e", "💩")
103
+ text = text.replace("\\U000e007f", "")
104
+ text = text.replace("\\U0001f970", "🥰")
105
+ text = text.replace("\\U0001f929", "🤩")
106
+ text = text.replace("\\U0001f6f9", "")
107
+ text = text.replace("🤍", "♥")
108
+ text = text.replace("🦠", "😷")
109
+ text = text.replace("🤢", "مقرف")
110
+ text = text.replace("🤮", "مقرف")
111
+ text = text.replace("🕠", "⌚")
112
+ text = text.replace("🤬", "😠")
113
+ text = text.replace("🤧", "😷")
114
+ text = text.replace("🥳", "🎉")
115
+ text = text.replace("🥵", "🔥")
116
+ text = text.replace("🥴", "☹")
117
+ text = text.replace("🤫", "🤐")
118
+ text = text.replace("🤥", "كذاب")
119
+ text = text.replace("\\u200d", " ")
120
+ text = text.replace("u200d", " ")
121
+ text = text.replace("\\u200c", " ")
122
+ text = text.replace("u200c", " ")
123
+ text = text.replace('"', "'")
124
+ text = text.replace("\\xa0", "")
125
+ text = text.replace("\\u2066", " ")
126
+ text = re.sub("\B\\\[Uu]\w+", "", text)
127
+ text = super(NewArabicPreprocessorBalanced, self).preprocess(text)
128
+
129
+ text = " ".join(text.split())
130
+ return text
131
+
132
+
133
+ """CNNMarbertArabicPreprocessor"""
134
+ # ASAD-CNN_MARBERT
135
+ class CNNMarbertArabicPreprocessor(ArabertPreprocessor):
136
+ def __init__(
137
+ self,
138
+ model_name,
139
+ keep_emojis=False,
140
+ remove_html_markup=True,
141
+ replace_urls_emails_mentions=True,
142
+ remove_elongations=True,
143
+ ):
144
+ if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
145
+ keep_emojis = True
146
+ remove_elongations = False
147
+ super().__init__(
148
+ model_name,
149
+ keep_emojis,
150
+ remove_html_markup,
151
+ replace_urls_emails_mentions,
152
+ remove_elongations,
153
+ )
154
+ self.true_model_name = model_name
155
+
156
+ def preprocess(self, text):
157
+ if "UBC-NLP" in self.true_model_name:
158
+ return self.ubc_prep(text)
+ return super().preprocess(text)  # fall through for non-UBC models, like the other preprocessors
159
+
160
+ def ubc_prep(self, text):
161
+ text = re.sub("\s", " ", text)
162
+ text = text.replace("\\n", " ")
163
+ text = araby.strip_tashkeel(text)
164
+ text = araby.strip_tatweel(text)
165
+ # replace all possible URLs
166
+ for reg in url_regexes:
167
+ text = re.sub(reg, " URL ", text)
168
+ text = re.sub("(URL\s*)+", " URL ", text)
169
+ # replace mentions with USER
170
+ text = re.sub(user_mention_regex, " USER ", text)
171
+ text = re.sub("(USER\s*)+", " USER ", text)
172
+ # replace hashtags with HASHTAG
173
+ # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
174
+ text = text.replace("#", " HASH ")
175
+ text = text.replace("_", " ")
176
+ text = " ".join(text.split())
177
+ text = super(CNNMarbertArabicPreprocessor, self).preprocess(text)
178
+ text = text.replace("\u200d", " ")
179
+ text = text.replace("u200d", " ")
180
+ text = text.replace("\u200c", " ")
181
+ text = text.replace("u200c", " ")
182
+ text = text.replace('"', "'")
183
+ # text = re.sub('[\d\.]+', ' NUM ', text)
184
+ # text = re.sub('(NUM\s*)+', ' NUM ', text)
185
+ text = multiple_char_pattern.sub(r"\1\1", text)
186
+ text = " ".join(text.split())
187
+ return text
188
+
189
+
190
+ """Trial5ArabicPreprocessor"""
191
+
192
+
193
+ class Trial5ArabicPreprocessor(ArabertPreprocessor):
194
+ def __init__(
195
+ self,
196
+ model_name,
197
+ keep_emojis=False,
198
+ remove_html_markup=True,
199
+ replace_urls_emails_mentions=True,
200
+ ):
201
+ if "UBC-NLP" in model_name:
202
+ keep_emojis = True
203
+ super().__init__(
204
+ model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
205
+ )
206
+ self.true_model_name = model_name
207
+
208
+ def preprocess(self, text):
209
+ if "UBC-NLP" in self.true_model_name:
210
+ return self.ubc_prep(text)
+ return super().preprocess(text)  # fall through for non-UBC models, like the other preprocessors
211
+
212
+ def ubc_prep(self, text):
213
+ text = re.sub("\s", " ", text)
214
+ text = text.replace("\\n", " ")
215
+ text = araby.strip_tashkeel(text)
216
+ text = araby.strip_tatweel(text)
217
+ # replace all possible URLs
218
+ for reg in url_regexes:
219
+ text = re.sub(reg, " URL ", text)
220
+ # replace mentions with USER
221
+ text = re.sub(user_mention_regex, " USER ", text)
222
+ # replace hashtags with HASHTAG
223
+ # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
224
+ text = text.replace("#", " HASH TAG ")
225
+ text = text.replace("_", " ")
226
+ text = " ".join(text.split())
227
+ text = super(Trial5ArabicPreprocessor, self).preprocess(text)
228
+ # text = text.replace("السلام عليكم"," ")
229
+ # text = text.replace(find_near_matches("السلام عليكم",text,max_deletions=3,max_l_dist=3)[0].matched," ")
230
+ return text
231
+
232
+
233
+ """SarcasmArabicPreprocessor"""
234
+
235
+
236
+ class SarcasmArabicPreprocessor(ArabertPreprocessor):
237
+ def __init__(
238
+ self,
239
+ model_name,
240
+ keep_emojis=False,
241
+ remove_html_markup=True,
242
+ replace_urls_emails_mentions=True,
243
+ ):
244
+ if "UBC-NLP" in model_name:
245
+ keep_emojis = True
246
+ super().__init__(
247
+ model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
248
+ )
249
+ self.true_model_name = model_name
250
+
251
+ def preprocess(self, text):
252
+ if "UBC-NLP" in self.true_model_name:
253
+ return self.ubc_prep(text)
254
+ else:
255
+ return super(SarcasmArabicPreprocessor, self).preprocess(text)
256
+
257
+ def ubc_prep(self, text):
258
+ text = re.sub("\s", " ", text)
259
+ text = text.replace("\\n", " ")
260
+ text = araby.strip_tashkeel(text)
261
+ text = araby.strip_tatweel(text)
262
+ # replace all possible URLs
263
+ for reg in url_regexes:
264
+ text = re.sub(reg, " URL ", text)
265
+ # replace mentions with USER
266
+ text = re.sub(user_mention_regex, " USER ", text)
267
+ # replace hashtags with HASHTAG
268
+ # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
269
+ text = text.replace("#", " HASH TAG ")
270
+ text = text.replace("_", " ")
271
+ text = text.replace('"', " ")
272
+ text = " ".join(text.split())
273
+ text = super(SarcasmArabicPreprocessor, self).preprocess(text)
274
+ return text
275
+
276
+
277
+ """NoAOAArabicPreprocessor"""
278
+
279
+
280
+ class NoAOAArabicPreprocessor(ArabertPreprocessor):
281
+ def __init__(
282
+ self,
283
+ model_name,
284
+ keep_emojis=False,
285
+ remove_html_markup=True,
286
+ replace_urls_emails_mentions=True,
287
+ ):
288
+ if "UBC-NLP" in model_name:
289
+ keep_emojis = True
290
+ super().__init__(
291
+ model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
292
+ )
293
+ self.true_model_name = model_name
294
+
295
+ def preprocess(self, text):
296
+ if "UBC-NLP" in self.true_model_name:
297
+ return self.ubc_prep(text)
298
+ else:
299
+ return super(NoAOAArabicPreprocessor, self).preprocess(text)
300
+
301
+ def ubc_prep(self, text):
302
+ text = re.sub("\s", " ", text)
303
+ text = text.replace("\\n", " ")
304
+ text = araby.strip_tashkeel(text)
305
+ text = araby.strip_tatweel(text)
306
+ # replace all possible URLs
307
+ for reg in url_regexes:
308
+ text = re.sub(reg, " URL ", text)
309
+ # replace mentions with USER
310
+ text = re.sub(user_mention_regex, " USER ", text)
311
+ # replace hashtags with HASHTAG
312
+ # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
313
+ text = text.replace("#", " HASH TAG ")
314
+ text = text.replace("_", " ")
315
+ text = " ".join(text.split())
316
+ text = super(NoAOAArabicPreprocessor, self).preprocess(text)
317
+ text = text.replace("السلام عليكم", " ")
318
+ text = text.replace("ورحمة الله وبركاته", " ")
319
+ matched = find_near_matches("السلام عليكم", text, max_deletions=3, max_l_dist=3)
320
+ if len(matched) > 0:
321
+ text = text.replace(matched[0].matched, " ")
322
+ matched = find_near_matches(
323
+ "ورحمة الله وبركاته", text, max_deletions=3, max_l_dist=3
324
+ )
325
+ if len(matched) > 0:
326
+ text = text.replace(matched[0].matched, " ")
327
+ return text
328
+
329
+
330
+ class CnnBertForSequenceClassification(BertPreTrainedModel):
331
+ def __init__(self, config):
332
+ super().__init__(config)
333
+ self.num_labels = config.num_labels
334
+ self.config = config
335
+
336
+ self.bert = BertModel(config)
337
+
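+ # CNN-over-BERT head: parallel Conv2d filters of widths 1-5 (32 filters each) are
+ # applied over the last four BERT hidden layers, stacked as 4 input channels.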
338
+ filter_sizes = [1, 2, 3, 4, 5]
339
+ num_filters = 32
340
+ self.convs1 = nn.ModuleList(
341
+ [nn.Conv2d(4, num_filters, (K, config.hidden_size)) for K in filter_sizes]
342
+ )
343
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
344
+ self.classifier = nn.Linear(len(filter_sizes) * num_filters, config.num_labels)
345
+
346
+ self.init_weights()
347
+
348
+ def forward(
349
+ self,
350
+ input_ids=None,
351
+ attention_mask=None,
352
+ token_type_ids=None,
353
+ position_ids=None,
354
+ head_mask=None,
355
+ inputs_embeds=None,
356
+ labels=None,
357
+ output_attentions=None,
358
+ output_hidden_states=None,
359
+ return_dict=None,
360
+ ):
361
+
362
+ return_dict = (
363
+ return_dict if return_dict is not None else self.config.use_return_dict
364
+ )
365
+
366
+ outputs = self.bert(
367
+ input_ids,
368
+ attention_mask=attention_mask,
369
+ token_type_ids=token_type_ids,
370
+ position_ids=position_ids,
371
+ head_mask=head_mask,
372
+ inputs_embeds=inputs_embeds,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=output_hidden_states,
375
+ return_dict=return_dict,
376
+ )
377
+
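+ # outputs[2] is the tuple of all hidden states (assumes the model config sets
+ # output_hidden_states=True); keep the last four layers as CNN input channels.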
378
+ x = outputs[2][-4:]
379
+
380
+ x = torch.stack(x, dim=1)
381
+ x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
382
+ x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
383
+ x = torch.cat(x, 1)
384
+ x = self.dropout(x)
385
+ logits = self.classifier(x)
386
+
387
+ loss = None
388
+ if labels is not None:
389
+ if self.config.problem_type is None:
390
+ if self.num_labels == 1:
391
+ self.config.problem_type = "regression"
392
+ elif self.num_labels > 1 and (
393
+ labels.dtype == torch.long or labels.dtype == torch.int
394
+ ):
395
+ self.config.problem_type = "single_label_classification"
396
+ else:
397
+ self.config.problem_type = "multi_label_classification"
398
+
399
+ if self.config.problem_type == "regression":
400
+ loss_fct = nn.MSELoss()
401
+ if self.num_labels == 1:
402
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
403
+ else:
404
+ loss = loss_fct(logits, labels)
405
+ elif self.config.problem_type == "single_label_classification":
406
+ loss_fct = nn.CrossEntropyLoss()
407
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
408
+ elif self.config.problem_type == "multi_label_classification":
409
+ loss_fct = nn.BCEWithLogitsLoss()
410
+ loss = loss_fct(logits, labels)
411
+ if not return_dict:
412
+ output = (logits,) + outputs[2:]
413
+ return ((loss,) + output) if loss is not None else output
414
+
415
+ return SequenceClassifierOutput(
416
+ loss=loss,
417
+ logits=logits,
418
+ hidden_states=None,
419
+ attentions=outputs.attentions,
420
+ )
421
+
422
+
423
+ class CNNTextClassificationPipeline:
424
+ def __init__(self, model_path, device, return_all_scores=False):
425
+ self.model_path = model_path
426
+ self.model = CnnBertForSequenceClassification.from_pretrained(self.model_path)
427
+ # Special handling
428
+ self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
429
+ if self.device.type == "cuda":
430
+ self.model = self.model.to(self.device)
431
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
432
+ self.return_all_scores = return_all_scores
433
+
434
+ @contextmanager
435
+ def device_placement(self):
436
+ """
437
+ Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.
438
+ Returns:
439
+ Context manager
440
+ Examples::
441
+ # Explicitly ask for tensor allocation on CUDA device :0
442
+ pipe = pipeline(..., device=0)
443
+ with pipe.device_placement():
444
+ # Every framework specific tensor allocation will be done on the request device
445
+ output = pipe(...)
446
+ """
447
+
448
+ if self.device.type == "cuda":
449
+ torch.cuda.set_device(self.device)
450
+
451
+ yield
452
+
453
+ def ensure_tensor_on_device(self, **inputs):
454
+ """
455
+ Ensure PyTorch tensors are on the specified device.
456
+ Args:
457
+ inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.
458
+ Return:
459
+ :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
460
+ """
461
+ return {
462
+ name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
463
+ for name, tensor in inputs.items()
464
+ }
465
+
466
+ def __call__(self, text):
467
+ """
468
+ Classify the text(s) given as inputs.
469
+ Args:
470
+ args (:obj:`str` or :obj:`List[str]`):
471
+ One or several texts (or one list of prompts) to classify.
472
+ Return:
473
+ A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the following keys:
474
+ - **label** (:obj:`str`) -- The label predicted.
475
+ - **score** (:obj:`float`) -- The corresponding probability.
476
+ If ``self.return_all_scores=True``, one such dictionary is returned per label.
477
+ """
478
+ # outputs = super().__call__(*args, **kwargs)
479
+ inputs = self.tokenizer.batch_encode_plus(
480
+ text,
481
+ add_special_tokens=True,
482
+ max_length=64,
483
+ padding=True,
484
+ truncation="longest_first",
485
+ return_tensors="pt",
486
+ )
487
+
488
+ with torch.no_grad():
489
+ inputs = self.ensure_tensor_on_device(**inputs)
490
+ predictions = self.model(**inputs)[0].cpu()
491
+
492
+ predictions = predictions.numpy()
493
+
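+ # A single logit gets a sigmoid; multiple logits get a softmax over the class axis.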
494
+ if self.model.config.num_labels == 1:
495
+ scores = 1.0 / (1.0 + np.exp(-predictions))
496
+ else:
497
+ scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)
498
+ if self.return_all_scores:
499
+ return [
500
+ [
501
+ {"label": self.model.config.id2label[i], "score": score.item()}
502
+ for i, score in enumerate(item)
503
+ ]
504
+ for item in scores
505
+ ]
506
+ else:
507
+ return [
508
+ {"label": self.inv_label_map[item.argmax()], "score": item.max().item()}
509
+ for item in scores
510
+ ]
backend/sarcasm.py ADDED
@@ -0,0 +1,21 @@
1
+ import streamlit as st
2
+ from .sa import predictor
3
+
4
+
5
+ def write():
6
+ st.markdown(
7
+ """
8
+ # Arabic Sarcasm Detection
9
+
10
+ This is a simple sarcasm detection app that uses the [MARBERT](https://huggingface.co/UBC-NLP/MARBERT) model trained on [ArSarcasm](https://github.com/iabufarha/ArSarcasm)
11
+ """
12
+ )
13
+
14
+ input_text = st.text_input(
15
+ "Enter your text here:",
16
+ )
17
+ if st.button("Predict"):
18
+ with st.spinner("Predicting..."):
19
+ prediction, scores = predictor.get_preds_from_sarcasm([input_text])
20
+ st.write(f"Result: {prediction[0]}")
21
+ st.write(f"Score: {scores[0]}")
backend/services.py ADDED
@@ -0,0 +1,519 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from functools import lru_cache
5
+ from typing import List
6
+ from urllib.parse import unquote
7
+
8
+ import more_itertools
9
+ import pandas as pd
10
+ import requests
11
+ import streamlit as st
12
+ import wikipedia
13
+ from codetiming import Timer
14
+ from fuzzysearch import find_near_matches
15
+ from googleapi import google
16
+ from tqdm.auto import tqdm
17
+ from transformers import (
18
+ AutoTokenizer,
19
+ GPT2LMHeadModel,
20
+ GPT2Tokenizer,
21
+ pipeline,
22
+ set_seed,
23
+ )
24
+
25
+ from .modeling_gpt2 import GPT2LMHeadModel as GROVERLMHeadModel
26
+ from .preprocess import ArabertPreprocessor
27
+ from .sa_utils import *
28
+ from .utils import download_models, softmax
29
+
30
+ logger = logging.getLogger(__name__)
31
+ # Taken and Modified from https://huggingface.co/spaces/flax-community/chef-transformer/blob/main/app.py
32
+ class TextGeneration:
33
+ def __init__(self):
34
+ self.debug = False
35
+ self.generation_pipline = {}
36
+ self.preprocessor = ArabertPreprocessor(model_name="aragpt2-mega")
37
+ self.tokenizer = GPT2Tokenizer.from_pretrained(
38
+ "aubmindlab/aragpt2-mega", use_fast=False
39
+ )
40
+ self.tokenizer.pad_token = self.tokenizer.eos_token
41
+ self.API_KEY = os.getenv("API_KEY")
42
+ self.headers = {"Authorization": f"Bearer {self.API_KEY}"}
43
+ # self.model_names_or_paths = {
44
+ # "aragpt2-medium": "D:/ML/Models/aragpt2-medium",
45
+ # "aragpt2-base": "D:/ML/Models/aragpt2-base",
46
+ # }
47
+ self.model_names_or_paths = {
48
+ # "aragpt2-medium": "aubmindlab/aragpt2-medium",
49
+ "aragpt2-base": "aubmindlab/aragpt2-base",
50
+ # "aragpt2-large": "aubmindlab/aragpt2-large",
51
+ "aragpt2-mega": "aubmindlab/aragpt2-mega",
52
+ }
53
+ set_seed(42)
54
+
55
+ def load_pipeline(self):
56
+ for model_name, model_path in self.model_names_or_paths.items():
57
+ if "base" in model_name or "medium" in model_name:
58
+ self.generation_pipline[model_name] = pipeline(
59
+ "text-generation",
60
+ model=GPT2LMHeadModel.from_pretrained(model_path),
61
+ tokenizer=self.tokenizer,
62
+ device=-1,
63
+ )
64
+ else:
65
+ self.generation_pipline[model_name] = pipeline(
66
+ "text-generation",
67
+ model=GROVERLMHeadModel.from_pretrained(model_path),
68
+ tokenizer=self.tokenizer,
69
+ device=-1,
70
+ )
71
+
72
+ def load(self):
73
+ if not self.debug:
74
+ self.load_pipeline()
75
+
76
+ def generate(
77
+ self,
78
+ model_name,
79
+ prompt,
80
+ max_new_tokens: int,
81
+ temperature: float,
82
+ top_k: int,
83
+ top_p: float,
84
+ repetition_penalty: float,
85
+ no_repeat_ngram_size: int,
86
+ do_sample: bool,
87
+ num_beams: int,
88
+ ):
89
+ logger.info(f"Generating with {model_name}")
90
+ prompt = self.preprocessor.preprocess(prompt)
91
+ return_full_text = False
92
+ return_text = True
93
+ num_return_sequences = 1
94
+ pad_token_id = 0
95
+ eos_token_id = 0
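+ # Clamp prompt length plus requested new tokens to AraGPT2's 1024-token context.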
96
+ input_tok = self.tokenizer.tokenize(prompt)
97
+ max_length = len(input_tok) + max_new_tokens
98
+ if max_length > 1024:
99
+ max_length = 1024
100
+ if not self.debug:
101
+ generated_text = self.generation_pipline[model_name.lower()](
102
+ prompt,
103
+ max_length=max_length,
104
+ temperature=temperature,
105
+ top_k=top_k,
106
+ top_p=top_p,
107
+ repetition_penalty=repetition_penalty,
108
+ no_repeat_ngram_size=no_repeat_ngram_size,
109
+ pad_token_id=pad_token_id,
110
+ eos_token_id=eos_token_id,
111
+ return_full_text=return_full_text,
112
+ return_text=return_text,
113
+ do_sample=do_sample,
114
+ num_beams=num_beams,
115
+ num_return_sequences=num_return_sequences,
116
+ )[0]["generated_text"]
117
+ else:
118
+ generated_text = self.generate_by_query(
119
+ prompt,
120
+ model_name,
121
+ max_length=max_length,
122
+ temperature=temperature,
123
+ top_k=top_k,
124
+ top_p=top_p,
125
+ repetition_penalty=repetition_penalty,
126
+ no_repeat_ngram_size=no_repeat_ngram_size,
127
+ pad_token_id=pad_token_id,
128
+ eos_token_id=eos_token_id,
129
+ return_full_text=return_full_text,
130
+ return_text=return_text,
131
+ do_sample=do_sample,
132
+ num_beams=num_beams,
133
+ num_return_sequences=num_return_sequences,
134
+ )
135
+ # print(generated_text)
136
+ if isinstance(generated_text, dict):
137
+ if "error" in generated_text:
138
+ if "is currently loading" in generated_text["error"]:
139
+ return f"Model is currently loading, estimated time is {generated_text['estimated_time']}"
140
+ return generated_text["error"]
141
+ else:
142
+ return "Something happened 🤷‍♂️!!"
143
+ else:
144
+ generated_text = generated_text[0]["generated_text"]
145
+
146
+ logger.info(f"Prompt: {prompt}")
147
+ logger.info(f"Generated text: {generated_text}")
148
+ return self.preprocessor.unpreprocess(generated_text)
149
+
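+ # Hosted-API path (used when self.debug is True): POST the request to the
+ # Hugging Face Inference API for the aubmindlab models.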
150
+ def query(self, payload, model_name):
151
+ data = json.dumps(payload)
152
+ url = (
153
+ "https://api-inference.huggingface.co/models/aubmindlab/"
154
+ + model_name.lower()
155
+ )
156
+ response = requests.request("POST", url, headers=self.headers, data=data)
157
+ return json.loads(response.content.decode("utf-8"))
158
+
159
+ def generate_by_query(
160
+ self,
161
+ prompt: str,
162
+ model_name: str,
163
+ max_length: int,
164
+ temperature: float,
165
+ top_k: int,
166
+ top_p: float,
167
+ repetition_penalty: float,
168
+ no_repeat_ngram_size: int,
169
+ pad_token_id: int,
170
+ eos_token_id: int,
171
+ return_full_text: int,
172
+ return_text: int,
173
+ do_sample: bool,
174
+ num_beams: int,
175
+ num_return_sequences: int,
176
+ ):
177
+ payload = {
178
+ "inputs": prompt,
179
+ "parameters": {
180
+ "max_length ": max_length,
181
+ "top_k": top_k,
182
+ "top_p": top_p,
183
+ "temperature": temperature,
184
+ "repetition_penalty": repetition_penalty,
185
+ "no_repeat_ngram_size": no_repeat_ngram_size,
186
+ "pad_token_id": pad_token_id,
187
+ "eos_token_id": eos_token_id,
188
+ "return_full_text": return_full_text,
189
+ "return_text": return_text,
191
+ "do_sample": do_sample,
192
+ "num_beams": num_beams,
193
+ "num_return_sequences": num_return_sequences,
194
+ },
195
+ "options": {
196
+ "use_cache": True,
197
+ },
198
+ }
199
+ return self.query(payload, model_name)
200
+
201
+
202
+ class SentimentAnalyzer:
203
+ def __init__(self):
204
+ self.sa_models = [
205
+ "sa_trial5_1",
206
+ # "sa_no_aoa_in_neutral",
207
+ # "sa_cnnbert",
208
+ # "sa_sarcasm",
209
+ # "sar_trial10",
210
+ # "sa_no_AOA",
211
+ ]
212
+ download_models(self.sa_models)
213
+ # fmt: off
214
+ self.processors = {
215
+ "sa_trial5_1": Trial5ArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
216
+ # "sa_no_aoa_in_neutral": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
217
+ # "sa_cnnbert": CNNMarbertArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
218
+ # "sa_sarcasm": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
219
+ # "sar_trial10": SarcasmArabicPreprocessor(model_name='UBC-NLP/MARBERT'),
220
+ # "sa_no_AOA": NewArabicPreprocessorBalanced(model_name='UBC-NLP/MARBERT'),
221
+ }
222
+
223
+ self.pipelines = {
224
+ "sa_trial5_1": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_trial5_1",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_trial5_1")],
225
+ # "sa_no_aoa_in_neutral": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_aoa_in_neutral",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_aoa_in_neutral")],
226
+ # "sa_cnnbert": [CNNTextClassificationPipeline("{}/train_{}/best_model".format("sa_cnnbert",i), device=-1, return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_cnnbert")],
227
+ # "sa_sarcasm": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_sarcasm",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_sarcasm")],
228
+ # "sar_trial10": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sar_trial10",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sar_trial10")],
229
+ # "sa_no_AOA": [pipeline("sentiment-analysis", model="{}/train_{}/best_model".format("sa_no_AOA",i), device=-1,return_all_scores =True) for i in tqdm(range(0,5), desc=f"Loading pipeline for model: sa_no_AOA")],
230
+ }
231
+ # fmt: on
232
+
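+ # NOTE: this method relies on the "sar_trial10" entries in self.processors and
+ # self.pipelines, which are commented out above; enable them before calling it.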
233
+ def get_preds_from_sarcasm(self, texts):
234
+ prep = self.processors["sar_trial10"]
235
+ prep_texts = [prep.preprocess(x) for x in texts]
236
+
237
+ preds_df = pd.DataFrame([])
238
+ for i in range(0, 5):
239
+ preds = []
240
+ for s in more_itertools.chunked(list(prep_texts), 128):
241
+ preds.extend(self.pipelines["sar_trial10"][i](s))
242
+ preds_df[f"model_{i}"] = preds
243
+
244
+ final_labels = []
245
+ final_scores = []
246
+ for id, row in preds_df.iterrows():
247
+ pos_total = 0
248
+ neu_total = 0
249
+ for pred in row[:]:
250
+ pos_total += pred[0]["score"]
251
+ neu_total += pred[1]["score"]
252
+
253
+ pos_avg = pos_total / len(row[:])
254
+ neu_avg = neu_total / len(row[:])
255
+
256
+ final_labels.append(
257
+ self.pipelines["sar_trial10"][0].model.config.id2label[
258
+ np.argmax([pos_avg, neu_avg])
259
+ ]
260
+ )
261
+ final_scores.append(np.max([pos_avg, neu_avg]))
262
+
263
+ return final_labels, final_scores
264
+
265
+ def get_preds_from_a_model(self, texts: List[str], model_name):
266
+ try:
267
+ prep = self.processors[model_name]
268
+
269
+ prep_texts = [prep.preprocess(x) for x in texts]
270
+ if model_name == "sa_sarcasm":
271
+ sarcasm_label, _ = self.get_preds_from_sarcasm(texts)
272
+ sarcastic_map = {"Not_Sarcastic": "غير ساخر", "Sarcastic": "ساخر"}
273
+ labeled_prep_texts = []
274
+ for t, l in zip(prep_texts, sarcasm_label):
275
+ labeled_prep_texts.append(sarcastic_map[l] + " [SEP] " + t)
276
+
277
+ preds_df = pd.DataFrame([])
278
+ for i in range(0, 5):
279
+ preds = []
280
+ for s in more_itertools.chunked(list(prep_texts), 128):
281
+ preds.extend(self.pipelines[model_name][i](s))
282
+ preds_df[f"model_{i}"] = preds
283
+
284
+ final_labels = []
285
+ final_scores = []
286
+ final_scores_list = []
287
+ for id, row in preds_df.iterrows():
288
+ pos_total = 0
289
+ neg_total = 0
290
+ neu_total = 0
291
+ for pred in row[:]:  # average over all five folds (the divisors below use 5)
292
+ pos_total += pred[0]["score"]
293
+ neu_total += pred[1]["score"]
294
+ neg_total += pred[2]["score"]
295
+
296
+ pos_avg = pos_total / 5
297
+ neu_avg = neu_total / 5
298
+ neg_avg = neg_total / 5
299
+
300
+ if model_name == "sa_no_aoa_in_neutral":
301
+ final_labels.append(
302
+ self.pipelines[model_name][0].model.config.id2label[
303
+ np.argmax([neu_avg, neg_avg, pos_avg])
304
+ ]
305
+ )
306
+ else:
307
+ final_labels.append(
308
+ self.pipelines[model_name][0].model.config.id2label[
309
+ np.argmax([pos_avg, neu_avg, neg_avg])
310
+ ]
311
+ )
312
+ final_scores.append(np.max([pos_avg, neu_avg, neg_avg]))
313
+ final_scores_list.append((pos_avg, neu_avg, neg_avg))
314
+ except RuntimeError as e:
315
+ if model_name == "sa_cnnbert":
316
+ return (
317
+ ["Neutral"] * len(texts),
318
+ [0.0] * len(texts),
319
+ [(0.0, 0.0, 0.0)] * len(texts),
320
+ )
321
+ else:
322
+ raise RuntimeError(e)
323
+ return final_labels, final_scores, final_scores_list
324
+
325
+ def predict(self, texts: List[str]):
326
+ logger.info(f"Predicting for: {texts}")
327
+ # (
328
+ # new_balanced_label,
329
+ # new_balanced_score,
330
+ # new_balanced_score_list,
331
+ # ) = self.get_preds_from_a_model(texts, "sa_no_aoa_in_neutral")
332
+ # (
333
+ # cnn_marbert_label,
334
+ # cnn_marbert_score,
335
+ # cnn_marbert_score_list,
336
+ # ) = self.get_preds_from_a_model(texts, "sa_cnnbert")
337
+ trial5_label, trial5_score, trial5_score_list = self.get_preds_from_a_model(
338
+ texts, "sa_trial5_1"
339
+ )
340
+ # no_aoa_label, no_aoa_score, no_aoa_score_list = self.get_preds_from_a_model(
341
+ # texts, "sa_no_AOA"
342
+ # )
343
+ # sarcasm_label, sarcasm_score, sarcasm_score_list = self.get_preds_from_a_model(
344
+ # texts, "sa_sarcasm"
345
+ # )
346
+
347
+ id_label_map = {0: "Positive", 1: "Neutral", 2: "Negative"}
348
+
349
+ final_ensemble_prediction = []
350
+ final_ensemble_score = []
351
+ final_ensemble_all_score = []
352
+ for entry in zip(
353
+ # new_balanced_score_list,
354
+ # cnn_marbert_score_list,
355
+ trial5_score_list,
356
+ # no_aoa_score_list,
357
+ # sarcasm_score_list,
358
+ ):
359
+ pos_score = 0
360
+ neu_score = 0
361
+ neg_score = 0
362
+ for s in entry:
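+ # Fixed per-class ensemble weights (scheme 1); the commented block below is scheme 2.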
363
+ pos_score += s[0] * 1.57
364
+ neu_score += s[1] * 0.98
365
+ neg_score += s[2] * 0.93
366
+
367
+ # weighted 2
368
+ # pos_score += s[0]*1.67
369
+ # neu_score += s[1]
370
+ # neg_score += s[2]*0.95
371
+
372
+ final_ensemble_prediction.append(
373
+ id_label_map[np.argmax([pos_score, neu_score, neg_score])]
374
+ )
375
+ final_ensemble_score.append(np.max([pos_score, neu_score, neg_score]))
376
+ final_ensemble_all_score.append(
377
+ softmax(np.array([pos_score, neu_score, neg_score])).tolist()
378
+ )
379
+
380
+ logger.info(f"Result: {final_ensemble_prediction}")
381
+ logger.info(f"Score: {final_ensemble_score}")
382
+ logger.info(f"All Scores: {final_ensemble_all_score}")
383
+ return final_ensemble_prediction, final_ensemble_score, final_ensemble_all_score
384
+
385
+
386
+ wikipedia.set_lang("ar")
387
+
388
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
389
+
390
+ preprocessor = ArabertPreprocessor("wissamantoun/araelectra-base-artydiqa")
391
+ logger.info("Loading QA Pipeline...")
392
+ tokenizer = AutoTokenizer.from_pretrained("wissamantoun/araelectra-base-artydiqa")
393
+ qa_pipe = pipeline("question-answering", model="wissamantoun/araelectra-base-artydiqa")
394
+ logger.info("Finished loading QA Pipeline...")
395
+
396
+
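+ # QA flow: Google-search ar.wikipedia.org for the question, fetch the top page,
+ # split it into reader-sized chunks, and run the AraELECTRA QA pipeline on each.
+ # Results are memoized per question.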
397
+ @lru_cache(maxsize=100)
398
+ def get_qa_answers(question):
399
+ logger.info("\n=================================================================")
400
+ logger.info(f"Question: {question}")
401
+
402
+ if "وسام أنطون" in question or "wissam antoun" in question.lower():
403
+ return {
404
+ "title": "Creator",
405
+ "results": [
406
+ {
407
+ "score": 1.0,
408
+ "new_start": 0,
409
+ "new_end": 12,
410
+ "new_answer": "My Creator 😜",
411
+ "original": "My Creator 😜",
412
+ "link": "https://github.com/WissamAntoun/",
413
+ }
414
+ ],
415
+ }
416
+ search_timer = Timer(
417
+ "search and wiki", text="Search and Wikipedia Time: {:.2f}", logger=logging.info
418
+ )
419
+ try:
420
+ search_timer.start()
421
+ search_results = google.search(
422
+ question + " site:ar.wikipedia.org", lang="ar", area="ar"
423
+ )
424
+ if len(search_results) == 0:
425
+ return {}
426
+
427
+ page_name = search_results[0].link.split("wiki/")[-1]
428
+ wiki_page = wikipedia.page(unquote(page_name))
429
+ wiki_page_content = wiki_page.content
430
+ search_timer.stop()
431
+ except Exception:
432
+ return {}
433
+
434
+ sections = []
435
+ for section in re.split("== .+ ==[^=]", wiki_page_content):
436
+ if not section.isspace():
437
+ prep_section = tokenizer.tokenize(preprocessor.preprocess(section))
438
+ if len(prep_section) > 500:
439
+ subsections = []
440
+ for subsection in re.split("=== .+ ===", section):
441
+ if subsection.isspace():
442
+ continue
443
+ prep_subsection = tokenizer.tokenize(
444
+ preprocessor.preprocess(subsection)
445
+ )
446
+ subsections.append(subsection)
447
+ # logger.info(f"Subsection found with length: {len(prep_subsection)}")
448
+ sections.extend(subsections)
449
+ else:
450
+ # logger.info(f"Regular Section with length: {len(prep_section)}")
451
+ sections.append(section)
452
+
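+ # Greedily pack consecutive sections into chunks of at most ~384 tokens so that
+ # each chunk fits the reader's context window.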
453
+ full_len_sections = []
454
+ temp_section = ""
455
+ for section in sections:
456
+ if (
457
+ len(tokenizer.tokenize(preprocessor.preprocess(temp_section)))
458
+ + len(tokenizer.tokenize(preprocessor.preprocess(section)))
459
+ > 384
460
+ ):
461
+ if temp_section == "":
462
+ temp_section = section
463
+ continue
464
+ full_len_sections.append(temp_section)
465
+ # logger.info(
466
+ # f"full section length: {len(tokenizer.tokenize(preprocessor.preprocess(temp_section)))}"
467
+ # )
468
+ temp_section = ""
469
+ else:
470
+ temp_section += " " + section + " "
471
+ if temp_section != "":
472
+ full_len_sections.append(temp_section)
473
+
474
+ reader_time = Timer("electra", text="Reader Time: {:.2f}", logger=logging.info)
475
+ reader_time.start()
476
+ results = qa_pipe(
477
+ question=[preprocessor.preprocess(question)] * len(full_len_sections),
478
+ context=[preprocessor.preprocess(x) for x in full_len_sections],
479
+ )
480
+
481
+ if not isinstance(results, list):
482
+ results = [results]
483
+
484
+ logger.info(f"Wiki Title: {unquote(page_name)}")
485
+ logger.info(f"Total Sections: {len(sections)}")
486
+ logger.info(f"Total Full Sections: {len(full_len_sections)}")
487
+
488
+ for result, section in zip(results, full_len_sections):
489
+ result["original"] = section
490
+ answer_match = find_near_matches(
491
+ " " + preprocessor.unpreprocess(result["answer"]) + " ",
492
+ result["original"],
493
+ max_l_dist=min(5, len(preprocessor.unpreprocess(result["answer"])) // 2),
494
+ max_deletions=0,
495
+ )
496
+ try:
497
+ result["new_start"] = answer_match[0].start
498
+ result["new_end"] = answer_match[0].end
499
+ result["new_answer"] = answer_match[0].matched
500
+ result["link"] = (
501
+ search_results[0].link + "#:~:text=" + result["new_answer"].strip()
502
+ )
503
+ except Exception:
504
+ result["new_start"] = result["start"]
505
+ result["new_end"] = result["end"]
506
+ result["new_answer"] = result["answer"]
507
+ result["original"] = preprocessor.preprocess(result["original"])
508
+ result["link"] = search_results[0].link
509
+ logger.info(f"Answers: {preprocessor.preprocess(result['new_answer'])}")
510
+
511
+ sorted_results = sorted(results, reverse=True, key=lambda x: x["score"])
512
+
513
+ return_dict = {}
514
+ return_dict["title"] = unquote(page_name)
515
+ return_dict["results"] = sorted_results
516
+
517
+ reader_time.stop()
518
+ logger.info(f"Total time spent: {reader_time.last + search_timer.last}")
519
+ return return_dict
backend/utils.py ADDED
@@ -0,0 +1,64 @@
1
+ import re
2
+ import numpy as np
3
+ import psutil
4
+ import os
5
+ from tqdm.auto import tqdm
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def get_current_ram_usage():
12
+ ram = psutil.virtual_memory()
13
+ return ram.available / 1024 / 1024 / 1024, ram.total / 1024 / 1024 / 1024
14
+
15
+
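+ # Download the five cross-validation checkpoints of each model from the
+ # researchaccount Hugging Face repos via wget.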
16
+ def download_models(models):
17
+ for model in tqdm(models, desc="Downloading models"):
18
+ logger.info(f"Downloading {model}")
19
+ for i in range(0, 5):
20
+ curr_dir = f"{model}/train_{i}/best_model/"
21
+ os.makedirs(curr_dir, exist_ok=True)
22
+ os.system(
23
+ f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/config.json -P {curr_dir}"
24
+ )
25
+ os.system(
26
+ f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/pytorch_model.bin -P {curr_dir}"
27
+ )
28
+ os.system(
29
+ f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/special_tokens_map.json -P {curr_dir}"
30
+ )
31
+ os.system(
32
+ f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/tokenizer_config.json -P {curr_dir}"
33
+ )
34
+ os.system(
35
+ f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/training_args.bin -P {curr_dir}"
36
+ )
37
+ os.system(
38
+ f"wget -q https://huggingface.co/researchaccount/{model}/resolve/main/train_{i}/best_model/vocab.txt -P {curr_dir}"
39
+ )
40
+
41
+
42
+ def softmax(x):
43
+ return np.exp(x) / sum(np.exp(x))
44
+
45
+
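+ # Inject the Google Analytics snippet into Streamlit's bundled static/index.html,
+ # skipping the write when a "G-" tag is already present.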
46
+ def ga(file):
47
+ code = """
48
+ <!-- Global site tag (gtag.js) - Google Analytics -->
49
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-NH9HWCW08F"></script>
50
+ <script>
51
+ window.dataLayer = window.dataLayer || [];
52
+ function gtag(){dataLayer.push(arguments);}
53
+ gtag('js', new Date());
54
+ gtag('config', 'G-NH9HWCW08F');
55
+ </script>
56
+ """
57
+
58
+ a = os.path.dirname(file) + "/static/index.html"
59
+ with open(a, "r") as f:
60
+ data = f.read()
61
+ if len(re.findall("G-", data)) == 0:
62
+ with open(a, "w") as ff:
63
+ newdata = re.sub("<head>", "<head>" + code, data)
64
+ ff.write(newdata)
packages.txt ADDED
@@ -0,0 +1,2 @@
1
+ openjdk-11-jre
2
+ curl
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ streamlit==0.84.2
2
+ arabic-reshaper==2.1.3
3
+ python-bidi==0.4.2
4
+ PyArabic
5
+ farasapy==0.0.14
6
+ emoji==1.4.2
7
+ awesome_streamlit
8
+ torch==1.9.0
9
+ transformers==4.10.0
10
+ psutil==5.8.0
11
+ fuzzysearch==0.7.3
12
+ more-itertools==8.9.0
13
+ cookiecutter
14
+ git+https://github.com/dantru7/Google-Search-API
15
+ codetiming==1.3.0
16
+ htbuilder
17
+ wikipedia==1.4.0
test.py ADDED
@@ -0,0 +1,10 @@
1
+ #%%
2
+ from transformers import GPT2Tokenizer
3
+
4
+ # %%
5
+ tok = GPT2Tokenizer.from_pretrained("D:/ML/Models/aragpt2-medium", use_fast=False)
6
+ # %%
7
+ tok.pad_token = tok.eos_token
8
+ #%%
9
+ tok.pad_token_id = tok.eos_token_id
10
+ # %%