xzxyx abxhr committed
Commit b71b65d (0 parents)

Duplicate from abxhr/design-project

Co-authored-by: Abshar Mohammed Aslam <abxhr@users.noreply.huggingface.co>

README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Arabic NLP Demo
+ emoji: ⌨
+ colorFrom: purple
+ colorTo: green
+ sdk: streamlit
+ app_file: app.py
+ pinned: true
+ duplicated_from: abxhr/design-project
+ ---
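
The YAML front matter above is the Hugging Face Spaces configuration card: `sdk: streamlit` tells the Hub how to launch the app, `app_file: app.py` names the entry point, `title`/`emoji`/`colorFrom`/`colorTo` style the Space tile, and `duplicated_from` records the source Space this repository was copied from.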
app.py ADDED
@@ -0,0 +1,38 @@
+ import awesome_streamlit as ast
+ import streamlit as st
+
+ from backend.utils import get_current_ram_usage, ga
+
+ import backend.aragpt
+ import backend.home
+ import backend.processor
+ import backend.sa
+ import backend.qa
+
+ st.set_page_config(
+     page_title="TEST", page_icon="📖", initial_sidebar_state="expanded", layout="wide"
+ )
+
+ ga(st.__file__)
+
+ PAGES = {
+     "Home": backend.home,
+     "Arabic Sentiment Analysis": backend.sa,
+     "Arabic Question Answering": backend.qa,
+ }
+
+
+ st.sidebar.title("Navigation")
+ selection = st.sidebar.radio("Pages", list(PAGES.keys()))
+
+ page = PAGES[selection]
+ # with st.spinner(f"Loading {selection} ..."):
+ ast.shared.components.write_page(page)
+
+ st.sidebar.header("Info")
+ st.sidebar.write("Arabic NLP by [**Abshar Mohammed Aslam** (*2019A7PS0233U*)](https://github.com/abxhr)")
+
+ st.sidebar.write("Submitted to *Dr. Sujala D. Shetty*")
+ # if st.sidebar.checkbox("Show RAM usage"):
+ #     ram = get_current_ram_usage()
+ #     st.sidebar.write("Ram usage: {:.2f}/{:.2f} GB".format(ram[0], ram[1]))
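
Note that `backend/utils.py` is not part of this commit, so the imported helpers are visible only by name. A minimal sketch of what `get_current_ram_usage` could look like, assuming psutil and the `(used, total)` GB tuple implied by the commented-out `"{:.2f}/{:.2f} GB"` sidebar formatting above:

# Hypothetical sketch only -- backend/utils.py is not included in this commit.
# Assumes psutil; returns (used_gb, total_gb) to match the commented-out
# "Ram usage: {:.2f}/{:.2f} GB" sidebar line.
import psutil


def get_current_ram_usage():
    mem = psutil.virtual_memory()
    gb = 1024 ** 3
    return mem.used / gb, mem.total / gb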
backend/__init__.py ADDED
File without changes
backend/aragpt.py ADDED
@@ -0,0 +1,189 @@
+ import streamlit as st
+ from .services import TextGeneration
+ from tokenizers import Tokenizer
+ from functools import lru_cache
+
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
+ @lru_cache(maxsize=1)
+ def load_text_generator():
+     generator = TextGeneration()
+     generator.load()
+     return generator
+
+
+ generator = load_text_generator()
+
+ qa_prompt = """
+ أجب عن السؤال التالي:
+ """
+ qa_prompt_post = """ الجواب هو """
+ qa_prompt_post_year = """ في سنة: """
+
+
+ def write():
+     st.markdown(
+         """
+ <h1 style="text-align:left;">Arabic Language Generation</h1>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # Sidebar
+
+     # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
+     st.sidebar.subheader("Configurable parameters")
+
+     model_name = st.sidebar.selectbox(
+         "Model Selector",
+         options=[
+             "AraGPT2-Base",
+             # "AraGPT2-Medium",
+             # "AraGPT2-Large",
+             "AraGPT2-Mega",
+         ],
+         index=0,
+     )
+
+     max_new_tokens = st.sidebar.number_input(
+         "Maximum length",
+         min_value=0,
+         max_value=1024,
+         value=100,
+         help="The maximum length of the sequence to be generated.",
+     )
+     temp = st.sidebar.slider(
+         "Temperature",
+         value=1.0,
+         min_value=0.1,
+         max_value=100.0,
+         help="The value used to modulate the next token probabilities.",
+     )
+     top_k = st.sidebar.number_input(
+         "Top k",
+         value=10,
+         help="The number of highest-probability vocabulary tokens to keep for top-k filtering.",
+     )
+     top_p = st.sidebar.number_input(
+         "Top p",
+         value=0.95,
+         help="If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
+     )
+     do_sample = st.sidebar.selectbox(
+         "Sampling?",
+         (True, False),
+         help="Whether or not to use sampling; use greedy decoding otherwise.",
+     )
+     num_beams = st.sidebar.number_input(
+         "Number of beams",
+         min_value=1,
+         max_value=10,
+         value=3,
+         help="The number of beams to use for beam search.",
+     )
+     repetition_penalty = st.sidebar.number_input(
+         "Repetition Penalty",
+         min_value=0.0,
+         value=3.0,
+         step=0.1,
+         help="The parameter for repetition penalty. 1.0 means no penalty.",
+     )
+     no_repeat_ngram_size = st.sidebar.number_input(
+         "No Repeat N-Gram Size",
+         min_value=0,
+         value=3,
+         help="If set to an int > 0, all n-grams of that size can only occur once.",
+     )
+
+     st.write("#")
+
+     col = st.columns(2)
+
+     col[0].image("images/AraGPT2.png", width=200)
+
+     st.markdown(
+         """
+
+ <h3 style="text-align:left;">AraGPT2 is a GPT2 model trained from scratch on 77GB of Arabic text.</h3>
+ <h4 style="text-align:left;">More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4>
+
+ <p style="text-align:left;"></p>
+ <p style="text-align:left;">Use the generation parameters on the sidebar to adjust generation quality.</p>
+ <p style="text-align:right;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     # col[0].write(
+     #     "AraGPT2 is trained from scratch on 77GB of Arabic text. More details in our [repo](https://github.com/aub-mind/arabert/tree/master/aragpt2)."
+     # )
+     # st.write("## Generate Arabic Text")
+
+     st.markdown(
+         """
+ <style>
+ p, div, input, label, textarea{
+     text-align: right;
+ }
+ </style>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     prompt = st.text_area(
+         "Prompt",
+         "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
+     )
+     if st.button("Generate"):
+         with st.spinner("Generating..."):
+             generated_text = generator.generate(
+                 prompt=prompt,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+         st.write(generated_text)
+
+     st.markdown("---")
+     st.subheader("")
+     st.markdown(
+         """
+ <p style="text-align:left;"></p>
+ <h2 style="text-align:left;">Zero-Shot Question Answering</h2>
+
+ <p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended.</p>
+ <p style="text-align:left;"></p>
+ """,
+         unsafe_allow_html=True,
+     )
+
+     question = st.text_input(
+         "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
+     )
+     is_date = st.checkbox("Help the model: Is the answer a date?")
+     if st.button("Answer"):
+         prompt2 = qa_prompt + question + qa_prompt_post
+         if is_date:
+             prompt2 += qa_prompt_post_year
+         else:
+             prompt2 += " : "
+         with st.spinner("Thinking..."):
+             answer = generator.generate(
+                 prompt=prompt2,
+                 model_name=model_name,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temp,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=do_sample,
+                 num_beams=num_beams,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+             )
+         st.write(answer)
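
`backend/services.py` is likewise absent from this commit, so `TextGeneration` is known only through the calls above: `load()` once at import time, then `generate(prompt=..., model_name=..., **decoding_params)`. A sketch of that interface, assuming a transformers text-generation pipeline; the model-ID mapping and the pipeline-based implementation are assumptions, not the committed code:

# Hypothetical sketch of backend/services.TextGeneration -- the real module is
# not part of this commit. Model IDs and the pipeline implementation are assumed.
from transformers import pipeline

ASSUMED_MODEL_IDS = {
    "AraGPT2-Base": "aubmindlab/aragpt2-base",
    "AraGPT2-Mega": "aubmindlab/aragpt2-mega",
}


class TextGeneration:
    def __init__(self):
        self.pipelines = {}

    def load(self):
        # Eagerly build one text-generation pipeline per selectable model.
        for name, model_id in ASSUMED_MODEL_IDS.items():
            self.pipelines[name] = pipeline("text-generation", model=model_id)

    def generate(self, prompt, model_name, **decoding_params):
        # The sidebar's decoding parameters (temperature, top_k, top_p,
        # repetition_penalty, do_sample, num_beams, no_repeat_ngram_size, ...)
        # pass straight through to transformers' generate().
        outputs = self.pipelines[model_name](prompt, **decoding_params)
        return outputs[0]["generated_text"]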
backend/home.py ADDED
@@ -0,0 +1,20 @@
+ import streamlit as st
+ import awesome_streamlit as ast
+
+
+ def write():
+     st.markdown(
+         """
+ # Arabic Natural Language Processing
+
+ Design project for **Arabic Natural Language Processing**, by [**Abshar Mohammed Aslam**](https://github.com/abxhr).
+ """
+     )
+     st.markdown("#")
+
+
+     st.markdown(
+         """
+
+ """
+     )
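
Each page module (`backend.home`, `backend.sa`, `backend.qa`, `backend.aragpt`) follows the awesome_streamlit convention of exposing a module-level `write()` that renders the page; `ast.shared.components.write_page(page)` in `app.py` simply invokes it. If `awesome_streamlit` were unavailable, a drop-in dispatcher could look like this (a sketch, not part of the commit):

# Hypothetical stand-in for ast.shared.components.write_page: every module in
# PAGES exposes write(), so dispatching is a single call.
def write_page(page_module):
    page_module.write()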
backend/modeling_gpt2.py ADDED
@@ -0,0 +1,1599 @@
+ # coding=utf-8
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ PyTorch OpenAI GPT-2 model.
+ Adapted from https://github.com/huggingface/transformers/blob/v4.0.1/src/transformers/models/gpt2/modeling_gpt2.py
+ and https://github.com/ghosthamlet/gpt2-ml-torch/blob/master/gpt2_ml_torch/modeling_gpt2.py
+ """
+
+
+ import logging
+ import os
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss, MSELoss
+ from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
+ from transformers.activations import ACT2FN
+ from transformers.file_utils import (
+     ModelOutput,
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     CausalLMOutputWithCrossAttentions,
+     SequenceClassifierOutputWithPast,
+     TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import (
+     Conv1D,
+     PreTrainedModel,
+     SequenceSummary,
+     find_pruneable_heads_and_indices,
+     prune_conv1d_layer,
+ )
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+
+ # The difference from Transformers is the code under _USE_GROVER
+ _USE_GROVER = True
+
+ logger = logging.getLogger(__name__)
+
+ _CONFIG_FOR_DOC = "GPT2Config"
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+     "gpt2",
+     "gpt2-medium",
+     "gpt2-large",
+     "gpt2-xl",
+     "distilgpt2",
+     # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
+ ]
+
+ logger.setLevel(logging.INFO)
+ console = logging.StreamHandler()
+ console.setLevel(logging.INFO)
+ logger.addHandler(console)
+
+ _GPT2_ML_TF_TO_TORCH = {
+     "LayerNorm_embed_norm": "emb_norm",
+     "pos_embed": "wpe.weight",
+     "word_embed": "wte.weight",
+     "layer": "h",
+     # Most importantly, these two layer norms must be put in the same position as in gpt2-ml,
+     # or the generated output is bad and just repeats the last token
+     "LayerNorm_mlp_ln0": "ln_1",
+     "LayerNorm_mlp_ln1": "ln_2",
+     "intermediate": "mlp.c_fc",
+     "output": "mlp.c_proj",
+     "query_layer": "attn.c_attn",
+     "key_layer": "attn.c_attn",
+     "value_layer": "attn.c_attn",
+     "context_projection_layer": "attn.c_proj",
+     "gamma": "weight",
+     "kernel": "weight",
+     "beta": "bias",
+     "bias": "bias",
+ }
+
+
+ def convert_gpt2_checkpoint_to_pytorch(
+     gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path
+ ):
+     # Construct model
+     if gpt2_config_file == "":
+         config = GPT2Config()
+     else:
+         config = GPT2Config.from_json_file(gpt2_config_file)
+     model = GPT2Model(config)
+
+     # Load weights from numpy
+     load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
+
+     # Save pytorch-model
+     pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
+     pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
+     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
+     torch.save(model.state_dict(), pytorch_weights_dump_path)
+     print("Save configuration file to {}".format(pytorch_config_dump_path))
+     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+         f.write(config.to_json_string())
+
+
+ # XXX: MUST do like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
+ # https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
+     """Load tf checkpoints in a pytorch model"""
+     try:
+         import re
+
+         import tensorflow as tf
+     except ImportError:
+         logger.error(
+             "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
+             "https://www.tensorflow.org/install/ for installation instructions."
+         )
+         raise
+     tf_path = os.path.abspath(gpt2_checkpoint_path)
+     logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
+     # Load weights from TF model
+     init_vars = tf.train.list_variables(tf_path)
+     names = []
+     arrays = []
+     for name, shape in init_vars:
+         logger.info("Loading TF weight {} with shape {}".format(name, shape))
+         array = tf.train.load_variable(tf_path, name)
+         names.append(name)
+         arrays.append(array.squeeze())
+
+     import copy
+
+     orig_model = copy.deepcopy(model)
+
+     for name, array in zip(names, arrays):
+         name = name[6:]  # skip "model/"
+         name = name.split("/")
+         pointer = model
+
+         attn_layer = ""
+         for m_name in name:
+             if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+                 scope_names = re.split(r"(\d+)", m_name)
+             else:
+                 scope_names = [m_name]
+             sname = scope_names[0]
+
+             if sname == "" or sname == "embeddings":
+                 continue
+             elif sname not in _GPT2_ML_TF_TO_TORCH:
+                 print("=========================================================")
+                 logger.info("Skip var name {}".format(scope_names))
+                 pointer = None
+                 break
+             else:
+                 tname = _GPT2_ML_TF_TO_TORCH[sname]
+                 if "." in tname:
+                     parent, child = tname.split(".")
+                     pointer = getattr(pointer, parent)
+                     pointer = getattr(pointer, child)
+                 else:
+                     pointer = getattr(pointer, tname)
+
+                 if tname == "attn.c_attn":
+                     attn_layer = sname
+
+             if len(scope_names) >= 2:
+                 num = int(scope_names[1])
+                 pointer = pointer[num]
+
+         if pointer is None:
+             continue
+         if attn_layer == "":
+             try:
+                 assert pointer.shape == array.shape
+             except AssertionError as e:
+                 e.args += (pointer.shape, array.shape)
+                 raise
+             logger.info(
+                 "Initialize PyTorch weight {}, {}, {}".format(
+                     name, array.mean(), pointer.mean()
+                 )
+             )
+         if attn_layer == "":
+             pointer.data = torch.from_numpy(array)
+         else:
+             shape = pointer.shape
+             d = torch.from_numpy(array)
+             is_bias = len(shape) == 1
+             end = int(shape[0 if is_bias else 1] / 3)
+             m = dict(
+                 query_layer=0,
+                 key_layer=end,
+                 value_layer=end * 2,
+             )
+             start = m[attn_layer]
+             end = start + end
+             if is_bias:
+                 pointer.data[start:end] = d
+             else:
+                 pointer.data[:, start:end] = d
+             logger.info(
+                 "Initialize PyTorch weight {}, {}, {}".format(
+                     name, array.mean(), pointer.mean()
+                 )
+             )
+
+     for name, params in orig_model.named_parameters():
+         for n, p in model.named_parameters():
+             if name == n:
+                 if params.equal(p):
+                     print("--------------------------")
+                     print(" %s not changed!" % n)
+     return model
+
+
+ class Attention(nn.Module):
+     def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
+         super().__init__()
+
+         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
+         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
+         assert n_state % config.n_head == 0
+         self.register_buffer(
+             "bias",
+             torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(
+                 1, 1, n_ctx, n_ctx
+             ),
+         )
+         self.register_buffer("masked_bias", torch.tensor(-1e4))
+         self.n_head = config.n_head
+         self.split_size = n_state
+         self.scale = scale
+         self.is_cross_attention = is_cross_attention
+         if self.is_cross_attention:
+             self.c_attn = Conv1D(2 * n_state, nx)
+             self.q_attn = Conv1D(n_state, nx)
+         else:
+             self.c_attn = Conv1D(3 * n_state, nx)
+         self.c_proj = Conv1D(n_state, nx)
+         self.attn_dropout = nn.Dropout(config.attn_pdrop)
+         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+         self.pruned_heads = set()
+
+     def prune_heads(self, heads):
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+         )
+         index_attn = torch.cat(
+             [index, index + self.split_size, index + (2 * self.split_size)]
+         )
+
+         # Prune conv1d layers
+         self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+         self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+         # Update hyper params
+         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+         self.n_head = self.n_head - len(heads)
+         self.pruned_heads = self.pruned_heads.union(heads)
+
+     def _attn(
+         self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False
+     ):
+         w = torch.matmul(q, k)
+         if self.scale:
+             w = w / (float(v.size(-1)) ** 0.5)
+         nd, ns = w.size(-2), w.size(-1)
+
+         if not self.is_cross_attention:
+             # if only "normal" attention layer implements causal mask
+             mask = self.bias[:, :, ns - nd : ns, :ns]
+             w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
+
+         if attention_mask is not None:
+             # Apply the attention mask
+             w = w + attention_mask
+
+         w = nn.Softmax(dim=-1)(w)
+         w = self.attn_dropout(w)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             w = w * head_mask
+
+         outputs = [torch.matmul(w, v)]
+         if output_attentions:
+             outputs.append(w)
+         return outputs
+
+     def merge_heads(self, x):
+         x = x.permute(0, 2, 1, 3).contiguous()
+         new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+         return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+
+     def split_heads(self, x, k=False):
+         new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+         x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
+         if k:
+             return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
+         else:
+             return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+     def forward(
+         self,
+         hidden_states,
+         layer_past=None,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         use_cache=False,
+         output_attentions=False,
+     ):
+         if encoder_hidden_states is not None:
+             assert hasattr(
+                 self, "q_attn"
+             ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
+             query = self.q_attn(hidden_states)
+             key, value = self.c_attn(encoder_hidden_states).split(
+                 self.split_size, dim=2
+             )
+             attention_mask = encoder_attention_mask
+         else:
+             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+         query = self.split_heads(query)
+         key = self.split_heads(key, k=True)
+         value = self.split_heads(value)
+         if layer_past is not None:
+             past_key, past_value = (
+                 layer_past[0].transpose(-2, -1),
+                 layer_past[1],
+             )  # transpose back cf below
+             key = torch.cat((past_key, key), dim=-1)
+             value = torch.cat((past_value, value), dim=-2)
+
+         if use_cache is True:
+             present = torch.stack(
+                 (key.transpose(-2, -1), value)
+             )  # transpose to have same shapes for stacking
+         else:
+             present = (None,)
+
+         attn_outputs = self._attn(
+             query, key, value, attention_mask, head_mask, output_attentions
+         )
+         a = attn_outputs[0]
+
+         a = self.merge_heads(a)
+         a = self.c_proj(a)
+         a = self.resid_dropout(a)
+
+         outputs = [a, present] + attn_outputs[1:]
+         return outputs  # a, present, (attentions)
+
+
+ class MLP(nn.Module):
+     def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
+         super().__init__()
+         nx = config.n_embd
+         self.c_fc = Conv1D(n_state, nx)
+         self.c_proj = Conv1D(nx, n_state)
+         self.act = ACT2FN[config.activation_function]
+         self.dropout = nn.Dropout(config.resid_pdrop)
+
+     def forward(self, x):
+         h = self.act(self.c_fc(x))
+         h2 = self.c_proj(h)
+         return self.dropout(h2)
+
+
+ class Block(nn.Module):
+     def __init__(self, n_ctx, config, scale=False):
+         super().__init__()
+         hidden_size = config.n_embd
+         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.attn = Attention(hidden_size, n_ctx, config, scale)
+         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         if config.add_cross_attention:
+             self.crossattention = Attention(
+                 hidden_size, n_ctx, config, scale, is_cross_attention=True
+             )
+             self.ln_cross_attn = nn.LayerNorm(
+                 hidden_size, eps=config.layer_norm_epsilon
+             )
+         self.mlp = MLP(inner_dim, config)
+
+     def forward(
+         self,
+         hidden_states,
+         layer_past=None,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         use_cache=False,
+         output_attentions=False,
+     ):
+         attn_outputs = self.attn(
+             hidden_states,
+             layer_past=layer_past,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+         )
+         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+         outputs = attn_outputs[1:]
+         # residual connection
+         hidden_states = attn_output + hidden_states
+
+         if encoder_hidden_states is not None:
+             # add one self-attention block for cross-attention
+             assert hasattr(
+                 self, "crossattention"
+             ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+             cross_attn_outputs = self.crossattention(
+                 self.ln_cross_attn(hidden_states),
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 output_attentions=output_attentions,
+             )
+             attn_output = cross_attn_outputs[0]
+             # residual connection
+             hidden_states = hidden_states + attn_output
+             outputs = (
+                 outputs + cross_attn_outputs[2:]
+             )  # add cross attentions if we output attention weights
+
+         feed_forward_hidden_states = self.mlp(self.ln_1(hidden_states))
+         # residual connection
+         hidden_states = hidden_states + feed_forward_hidden_states
+
+         hidden_states = self.ln_2(hidden_states)
+
+         outputs = [hidden_states] + outputs
+         return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
+ class GPT2PreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = GPT2Config
+     load_tf_weights = load_tf_weights_in_gpt2
+     base_model_prefix = "transformer"
+     is_parallelizable = True
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+     def _init_weights(self, module):
+         """Initialize the weights."""
+         if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ @dataclass
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
+     """
+     Base class for outputs of models predicting if two sentences are consecutive or not.
+
+     Args:
+         loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
+             Language modeling loss.
+         mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
+             Multiple choice classification loss.
+         logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+         mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
+             Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+         past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+             List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
+             batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+             Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+             :obj:`past_key_values` input) to speed up sequential decoding.
+         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+             Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+         attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+             sequence_length, sequence_length)`.
+
+             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     mc_loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     mc_logits: torch.FloatTensor = None
+     past_key_values: Optional[List[torch.FloatTensor]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ GPT2_START_DOCSTRING = r"""
+
+     This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+     methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+     pruning heads etc.)
+
+     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+     general usage and behavior.
+
+     Parameters:
+         config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the