ChancesYuan committed
Commit 06a8327
1 Parent(s): 3124969

Upload 12 files

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+dataset/fb15k237/entity2textlong.txt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,17 +1,261 @@
 import gradio as gr
+from collections import defaultdict
+from transformers import BertTokenizer, BertForMaskedLM
+import jsonlines
+import torch
+from src.modeling_bert import EXBertForMaskedLM
+from higher.patch import monkeypatch as make_functional
+# from src.models.one_shot_learner import OneShotLearner
 
-def edit_process(title, context):
-    return f"Title:{title}\nContext:{context}\n...", f"Title:{title}\nContext:{context}\n..."
-
-def add_process(title, context, img):
-    return f"Title:{title}\nContext:{context}\n...{img}", f"Title:{title}\nContext:{context}\n...{img}"
+### load KGE model
+edit_origin_model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path="zjunlp/KGEditor", subfolder="E-FB15k237")
+edit_ex_model = EXBertForMaskedLM.from_pretrained(pretrained_model_name_or_path="zjunlp/KGEditor", subfolder="E-FB15k237")
+
+edit_learner = torch.load("./learner_checkpoint/edit/learner_params.pt", map_location=torch.device('cpu'))
+add_learner = torch.load("./learner_checkpoint/add/learner_params.pt", map_location=torch.device('cpu'))
+
+add_origin_model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path="zjunlp/KGEditor", subfolder="A-FB15k237")
+add_ex_model = EXBertForMaskedLM.from_pretrained(pretrained_model_name_or_path="zjunlp/KGEditor", subfolder="A-FB15k237")
+
+### init inputs
+ent_name2id = defaultdict(str)
+id2ent_name = defaultdict(str)
+rel_name2id = defaultdict(str)
+id2ent_text = defaultdict(str)
+id2rel_text = defaultdict(str)
+corrupt_triple = defaultdict(list)
+
+### init tokenizer
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+add_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='zjunlp/KGEditor', subfolder="E-FB15k237")
+
+def init_triple_input():
+    global ent2token
+    global ent2id
+    global id2ent
+    global rel2token
+
+    with open("./dataset/fb15k237/relations.txt", "r") as f:
+        lines = f.readlines()
+        relations = []
+        for line in lines:
+            relations.append(line.strip().split('\t')[0])
+
+    rel2token = {ent: f"[RELATION_{i}]" for i, ent in enumerate(relations)}
+
+    with open("./dataset/fb15k237/entity2text.txt", "r") as f:
+        for line in f.readlines():
+            id, name = line.rstrip('\n').split('\t')
+            ent_name2id[name] = id
+            id2ent_name[id] = name
+
+    with open("./dataset/fb15k237/relation2text.txt", "r") as f:
+        for line in f.readlines():
+            id, name = line.rstrip('\n').split('\t')
+            rel_name2id[name] = id
+            id2rel_text[id] = name
+
+    with open("./dataset/fb15k237/entity2textlong.txt", "r") as f:
+        for line in f.readlines():
+            id, text = line.rstrip('\n').split('\t')
+            id2ent_text[id] = text.replace("\\n", " ").replace("\\", "")
+
+    entities = list(id2ent_text.keys())
+    ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
+    ent2id = {ent: i for i, ent in enumerate(entities)}
+    id2ent = {i: ent for i, ent in enumerate(entities)}
+
+    with jsonlines.open("./dataset/fb15k237/edit_test.jsonl") as f:
+        lines = []
+        for d in f:
+            corrupt_triple[" ".join(d["ori"])] = d["cor"]
+
+def solve(triple, alter_label, edit_task):
+    h, r, t = triple.split("|")
+    if h == "[MASK]":
+        text_a = "[MASK]"
+        text_b = id2rel_text[r] + " " + rel2token[r]
+        text_c = ent2token[ent_name2id[t]] + " " + id2ent_text[ent_name2id[t]]
+        origin_label = corrupt_triple[" ".join([ent_name2id[alter_label], r, ent_name2id[t]])][0] if edit_task else ent_name2id[alter_label]
+    else:
+        text_a = ent2token[ent_name2id[h]]
+        # text_b = id2rel_text[r] + "[PAD]"
+        text_b = id2rel_text[r] + " " + rel2token[r]
+        text_c = "[MASK]" + " " + id2ent_text[ent_name2id[h]]
+        origin_label = corrupt_triple[" ".join([ent_name2id[h], r, ent_name2id[alter_label]])][2] if edit_task else ent_name2id[alter_label]
+
+    if text_a == "[MASK]":
+        input_text_a = tokenizer.sep_token.join(["[MASK]", id2rel_text[r] + "[PAD]"])
+        input_text_b = "[PAD]" + " " + id2ent_text[ent_name2id[t]]
+    else:
+        input_text_a = "[PAD] "
+        input_text_b = tokenizer.sep_token.join([id2rel_text[r] + "[PAD]", "[MASK]" + " " + id2ent_text[ent_name2id[h]]])
+
+    cond_inputs_text = "{} >> {} || {}".format(
+        add_tokenizer.added_tokens_decoder[ent2id[origin_label] + len(tokenizer)],
+        add_tokenizer.added_tokens_decoder[ent2id[ent_name2id[alter_label]] + len(tokenizer)],
+        input_text_a + input_text_b
+    )
+
+    inputs = tokenizer(
+        f"{text_a} [SEP] {text_b} [SEP] {text_c}",
+        truncation="longest_first",
+        max_length=64,
+        padding="longest",
+        add_special_tokens=True,
+    )
+
+    edit_inputs = tokenizer(
+        input_text_a,
+        input_text_b,
+        truncation="longest_first",
+        max_length=64,
+        padding="longest",
+        add_special_tokens=True,
+    )
+
+    cond_inputs = tokenizer(
+        cond_inputs_text,
+        truncation=True,
+        max_length=64,
+        padding="max_length",
+        add_special_tokens=True,
+    )
+
+    inputs = {
+        "input_ids": torch.tensor(inputs["input_ids"]).unsqueeze(dim=0),
+        "attention_mask": torch.tensor(inputs["attention_mask"]).unsqueeze(dim=0),
+        "token_type_ids": torch.tensor(inputs["token_type_ids"]).unsqueeze(dim=0)
+    }
+
+    edit_inputs = {
+        "input_ids": torch.tensor(edit_inputs["input_ids"]).unsqueeze(dim=0),
+        "attention_mask": torch.tensor(edit_inputs["attention_mask"]).unsqueeze(dim=0),
+        "token_type_ids": torch.tensor(edit_inputs["token_type_ids"]).unsqueeze(dim=0)
+    }
+
+    cond_inputs = {
+        "input_ids": torch.tensor(cond_inputs["input_ids"]).unsqueeze(dim=0),
+        "attention_mask": torch.tensor(cond_inputs["attention_mask"]).unsqueeze(dim=0),
+        "token_type_ids": torch.tensor(cond_inputs["token_type_ids"]).unsqueeze(dim=0)
+    }
+
+    return inputs, cond_inputs, edit_inputs
+
+def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
+    with torch.enable_grad():
+        logits = ex_model.eval()(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+        ).logits
+        # print(logits.shape)
+        # logits_orig, logit_for_grad, _ = logits.split([
+        #     len(inputs["input_ids"]) - 1,
+        #     1,
+        #     0,
+        # ])
+        input_ids = inputs['input_ids']
+        _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
+        mask_logits = logits[:, mask_idx, 30522:45473].squeeze(dim=0)
+
+        grads = torch.autograd.grad(
+            # cross_entropy
+            torch.nn.functional.cross_entropy(
+                mask_logits[-1:, :],
+                torch.tensor([alter_label]),
+                reduction="none",
+            ).mean(-1),
+            ex_model.parameters(),
+        )
+
+    grads = {
+        name: grad
+        for (name, _), grad in zip(ex_model.named_parameters(), grads)
+    }
+
+    # cond_inputs contains [PAD] tokens
+    params_dict = learner(
+        cond_inputs["input_ids"][-1:],
+        cond_inputs["attention_mask"][-1:],
+        grads=grads,
+    )
+
+    return params_dict
+
+def edit_process(edit_input, alter_label):
+    inputs, cond_inputs, edit_inputs = solve(edit_input, alter_label, edit_task=True)
+
+    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
+    logits = edit_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
+    logits = logits[mask_idx, :]
+
+    ### origin output
+    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
+    origin_entity_order = origin_entity_order.squeeze(dim=0)
+    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+
+    ### edit output
+    fmodel = make_functional(edit_ex_model).eval()
+    params_dict = get_logits_orig_params_dict(inputs, cond_inputs, ent2id[ent_name2id[alter_label]], edit_ex_model, edit_learner)
+    edit_logits = fmodel(
+        input_ids=inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        # add delta theta
+        params=[
+            params_dict.get(n, 0) + p
+            for n, p in edit_ex_model.named_parameters()
+        ],
+    ).logits[:, :, 30522:45473].squeeze()
+
+    edit_logits = edit_logits[mask_idx, :]
+    _, edit_entity_order = torch.sort(edit_logits, dim=1, descending=True)
+    edit_entity_order = edit_entity_order.squeeze(dim=0)
+    edit_top3 = [id2ent_name[id2ent[edit_entity_order[i].item()]] for i in range(3)]
+
+    return "\n".join(origin_top3), "\n".join(edit_top3)
+
+def add_process(edit_input, alter_label):
+    inputs, cond_inputs, add_inputs = solve(edit_input, alter_label, edit_task=False)
+
+    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
+    logits = add_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
+    logits = logits[mask_idx, :]
+
+    ### origin output
+    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
+    origin_entity_order = origin_entity_order.squeeze(dim=0)
+    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+
+    ### add output
+    fmodel = make_functional(add_ex_model).eval()
+    params_dict = get_logits_orig_params_dict(inputs, cond_inputs, ent2id[ent_name2id[alter_label]], add_ex_model, add_learner)
+    add_logits = fmodel(
+        input_ids=inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        # add delta theta
+        params=[
+            params_dict.get(n, 0) + p
+            for n, p in add_ex_model.named_parameters()
+        ],
+    ).logits[:, :, 30522:45473].squeeze()
+
+    add_logits = add_logits[mask_idx, :]
+    _, add_entity_order = torch.sort(add_logits, dim=1, descending=True)
+    add_entity_order = add_entity_order.squeeze(dim=0)
+    add_top3 = [id2ent_name[id2ent[add_entity_order[i].item()]] for i in range(3)]
+
+    return "\n".join(origin_top3), "\n".join(add_top3)
+
 
 with gr.Blocks() as demo:
+    init_triple_input()
+    ### example
+    # edit_process("[MASK]|/people/person/profession|Jack Black", "Kellie Martin")
+    add_process("Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs")
     gr.Markdown("# KGE Editing")
 
     # multiple tabs
     with gr.Tabs():
-
         with gr.TabItem("E-FB15k237"):
             with gr.Row():
                 with gr.Column():
@@ -25,7 +269,7 @@ with gr.Blocks() as demo:
                     edit_output = gr.Textbox(label="After Edit", lines=3, placeholder="")
 
             gr.Examples(
-                examples=[["[MASK] r1 t1", "h1"], ["[MASK] r2 t2", "h2"]],
+                examples=[["[MASK]|/people/person/profession|Jack Black", "Kellie Martin"], ["Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"]],
                 inputs=[edit_input, alter_label],
                 outputs=[origin_output, edit_output],
                 fn=edit_process,
@@ -37,7 +281,7 @@ with gr.Blocks() as demo:
                 with gr.Column():
                     add_input = gr.Textbox(label="Input", lines=1, placeholder="New triple input")
 
-                    mask_head = gr.Textbox(label="Head/Tail", lines=1, placeholder="1:head / 0:tail")
+                    inductive_entity = gr.Textbox(label="Inductive Entity", lines=1, placeholder="Entity Name")
                    add_button = gr.Button("Add")
 
                 with gr.Column():
@@ -45,15 +289,14 @@ with gr.Blocks() as demo:
                    add_output = gr.Textbox(label="Add Results", lines=3, placeholder="")
 
            gr.Examples(
-                examples=[["h1 r1 t1", "1"], ["h2 r2 t2", "1"]],
-                inputs=[add_input, mask_head],
+                examples=[["Jane Wyman|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"], ["Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"]],
+                inputs=[add_input, inductive_entity],
                outputs=[add_origin_output, add_output],
                fn=add_process,
                cache_examples=True,
            )
 
-    # origin_button.click(fn=origin_preditcion, inputs=[input, alter_label], outputs=origin_output)
     edit_button.click(fn=edit_process, inputs=[edit_input, alter_label], outputs=[origin_output, edit_output])
-    add_button.click(fn=add_process, inputs=[add_input, mask_head], outputs=[add_origin_output, add_output])
+    add_button.click(fn=add_process, inputs=[add_input, inductive_entity], outputs=[add_origin_output, add_output])
 
 demo.launch()
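A minimal usage sketch of the flow implemented above, assuming the definitions in the new app.py are in scope (for example, run in an interpreter before the final demo.launch() call). The triple strings and entity names are taken from the demo's own Examples; inputs are written as "head|relation|tail" with the slot to re-rank replaced by [MASK], and each function returns the top-3 entities before and after the learner's parameter shifts are added to the frozen weights through higher's monkeypatch.

init_triple_input()  # build ent_name2id / id2ent_text / rel2token lookups from ./dataset/fb15k237

# Edit task (E-FB15k237 tab): make "Kellie Martin" the predicted head entity.
before, after = edit_process("[MASK]|/people/person/profession|Jack Black", "Kellie Martin")
print(before)  # top-3 entities from the unedited model
print(after)   # top-3 entities after the hyper-network's delta parameters are applied

# Add task (A-FB15k237 tab): inject a new triple whose tail entity is masked.
before, after = add_process(
    "Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]",
    "Palm Springs",
)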
dataset/fb15k237/edit_test.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
dataset/fb15k237/entity2text.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/fb15k237/entity2textlong.txt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c6028d81296e311076eed7cfbf5dc3c4174b68394639148e32212ad49aa6c7f
+size 13063994
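The long-description file is tracked with Git LFS (see the .gitattributes change above), so only the pointer is shown in this diff. A hedged sketch of the format init_triple_input() expects once the real file is fetched: one tab-separated entity ID and long description per line.

with open("./dataset/fb15k237/entity2textlong.txt") as f:
    for line in f:
        ent_id, description = line.rstrip("\n").split("\t")  # "<entity MID>\t<long description>"
        description = description.replace("\\n", " ").replace("\\", "")  # same clean-up as app.py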
dataset/fb15k237/relation2text.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/fb15k237/relations.txt ADDED
@@ -0,0 +1,237 @@
1
+ /soccer/football_team/current_roster./soccer/football_roster_position/position
2
+ /music/artist/origin
3
+ /ice_hockey/hockey_team/current_roster./sports/sports_team_roster/position
4
+ /food/food/nutrients./food/nutrition_fact/nutrient
5
+ /film/actor/film./film/performance/film
6
+ /award/award_nominee/award_nominations./award/award_nomination/nominated_for
7
+ /government/political_party/politicians_in_this_party./government/political_party_tenure/politician
8
+ /base/schemastaging/person_extra/net_worth./measurement_unit/dated_money_value/currency
9
+ /people/deceased_person/place_of_death
10
+ /people/person/profession
11
+ /location/administrative_division/first_level_division_of
12
+ /base/marchmadness/ncaa_basketball_tournament/seeds./base/marchmadness/ncaa_tournament_seed/team
13
+ /education/university/international_tuition./measurement_unit/dated_money_value/currency
14
+ /location/us_county/county_seat
15
+ /location/location/partially_contains
16
+ /tv/tv_program/program_creator
17
+ /film/film/music
18
+ /tv/tv_program/languages
19
+ /common/topic/webpage./common/webpage/category
20
+ /user/tsegaran/random/taxonomy_subject/entry./user/tsegaran/random/taxonomy_entry/taxonomy
21
+ /education/field_of_study/students_majoring./education/education/major_field_of_study
22
+ /business/business_operation/assets./measurement_unit/dated_money_value/currency
23
+ /film/film_set_designer/film_sets_designed
24
+ /dataworld/gardening_hint/split_to
25
+ /people/person/languages
26
+ /business/job_title/people_with_this_title./business/employment_tenure/company
27
+ /location/country/form_of_government
28
+ /base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/service_language
29
+ /people/person/place_of_birth
30
+ /sports/sports_team/colors
31
+ /education/educational_institution/school_type
32
+ /award/award_category/winners./award/award_honor/award_winner
33
+ /organization/organization/headquarters./location/mailing_address/citytown
34
+ /education/educational_degree/people_with_this_degree./education/education/student
35
+ /government/legislative_session/members./government/government_position_held/legislative_sessions
36
+ /film/film/distributors./film/film_film_distributor_relationship/film_distribution_medium
37
+ /education/educational_degree/people_with_this_degree./education/education/major_field_of_study
38
+ /location/hud_county_place/county
39
+ /location/administrative_division/country
40
+ /film/film/film_production_design_by
41
+ /award/award_winning_work/awards_won./award/award_honor/award
42
+ /organization/organization/headquarters./location/mailing_address/state_province_region
43
+ /base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/contact_category
44
+ /tv/tv_program/country_of_origin
45
+ /olympics/olympic_participating_country/medals_won./olympics/olympic_medal_honor/medal
46
+ /location/country/second_level_divisions
47
+ /award/award_ceremony/awards_presented./award/award_honor/honored_for
48
+ /organization/organization_member/member_of./organization/organization_membership/organization
49
+ /education/educational_institution/campuses
50
+ /music/artist/contribution./music/recording_contribution/performance_role
51
+ /award/ranked_item/appears_in_ranked_lists./award/ranking/list
52
+ /people/person/religion
53
+ /travel/travel_destination/climate./travel/travel_destination_monthly_climate/month
54
+ /film/special_film_performance_type/film_performance_type./film/performance/film
55
+ /award/award_nominee/award_nominations./award/award_nomination/award
56
+ /location/statistical_region/religions./location/religion_percentage/religion
57
+ /sports/sports_league_draft/picks./sports/sports_league_draft_pick/school
58
+ /film/film/distributors./film/film_film_distributor_relationship/region
59
+ /government/politician/government_positions_held./government/government_position_held/legislative_sessions
60
+ /organization/role/leaders./organization/leadership/organization
61
+ /tv/tv_network/programs./tv/tv_network_duration/program
62
+ /soccer/football_team/current_roster./sports/sports_team_roster/position
63
+ /music/instrument/instrumentalists
64
+ /business/business_operation/operating_income./measurement_unit/dated_money_value/currency
65
+ /people/cause_of_death/people
66
+ /film/film/film_art_direction_by
67
+ /people/person/sibling_s./people/sibling_relationship/sibling
68
+ /film/film/cinematography
69
+ /film/actor/dubbing_performances./film/dubbing_performance/language
70
+ /base/biblioness/bibs_location/state
71
+ /base/petbreeds/city_with_dogs/top_breeds./base/petbreeds/dog_city_relationship/dog_breed
72
+ /people/person/gender
73
+ /education/field_of_study/students_majoring./education/education/student
74
+ /base/popstra/celebrity/dated./base/popstra/dated/participant
75
+ /sports/sports_team/roster./american_football/football_roster_position/position
76
+ /award/award_winner/awards_won./award/award_honor/award_winner
77
+ /olympics/olympic_participating_country/medals_won./olympics/olympic_medal_honor/olympics
78
+ /film/director/film
79
+ /tv/tv_producer/programs_produced./tv/tv_producer_term/program
80
+ /film/film_distributor/films_distributed./film/film_film_distributor_relationship/film
81
+ /olympics/olympic_games/sports
82
+ /music/record_label/artist
83
+ /education/university/local_tuition./measurement_unit/dated_money_value/currency
84
+ /film/film/story_by
85
+ /people/person/spouse_s./people/marriage/spouse
86
+ /sports/sports_league/teams./sports/sports_league_participation/team
87
+ /people/profession/specialization_of
88
+ /base/americancomedy/celebrity_impressionist/celebrities_impersonated
89
+ /tv/tv_program/genre
90
+ /award/award_category/nominees./award/award_nomination/nominated_for
91
+ /language/human_language/countries_spoken_in
92
+ /organization/organization/headquarters./location/mailing_address/country
93
+ /location/statistical_region/gdp_real./measurement_unit/adjusted_money_value/adjustment_currency
94
+ /education/university/fraternities_and_sororities
95
+ /award/award_nominee/award_nominations./award/award_nomination/award_nominee
96
+ /military/military_combatant/military_conflicts./military/military_combatant_group/combatants
97
+ /award/award_nominated_work/award_nominations./award/award_nomination/nominated_for
98
+ /location/location/time_zones
99
+ /film/film/dubbing_performances./film/dubbing_performance/actor
100
+ /film/film_subject/films
101
+ /education/educational_degree/people_with_this_degree./education/education/institution
102
+ /education/educational_institution/colors
103
+ /award/award_category/category_of
104
+ /tv/tv_personality/tv_regular_appearances./tv/tv_regular_personal_appearance/program
105
+ /film/film/language
106
+ /music/group_member/membership./music/group_membership/group
107
+ /business/business_operation/revenue./measurement_unit/dated_money_value/currency
108
+ /film/film/film_festivals
109
+ /film/actor/film./film/performance/special_performance_type
110
+ /organization/non_profit_organization/registered_with./organization/non_profit_registration/registering_agency
111
+ /government/politician/government_positions_held./government/government_position_held/jurisdiction_of_office
112
+ /base/aareas/schema/administrative_area/administrative_parent
113
+ /award/award_winning_work/awards_won./award/award_honor/award_winner
114
+ /organization/organization/place_founded
115
+ /soccer/football_player/current_team./sports/sports_team_roster/team
116
+ /government/politician/government_positions_held./government/government_position_held/basic_title
117
+ /music/artist/track_contributions./music/track_contribution/role
118
+ /base/localfood/seasonal_month/produce_available./base/localfood/produce_availability/seasonal_months
119
+ /celebrities/celebrity/celebrity_friends./celebrities/friendship/friend
120
+ /sports/professional_sports_team/draft_picks./sports/sports_league_draft_pick/school
121
+ /award/hall_of_fame/inductees./award/hall_of_fame_induction/inductee
122
+ /influence/influence_node/peers./influence/peer_relationship/peers
123
+ /medicine/disease/risk_factors
124
+ /broadcast/content/artist
125
+ /film/film/estimated_budget./measurement_unit/dated_money_value/currency
126
+ /military/military_conflict/combatants./military/military_combatant_group/combatants
127
+ /location/capital_of_administrative_division/capital_of./location/administrative_division_capital_relationship/administrative_division
128
+ /tv/tv_program/regular_cast./tv/regular_tv_appearance/actor
129
+ /people/deceased_person/place_of_burial
130
+ /location/location/adjoin_s./location/adjoining_relationship/adjoins
131
+ /music/group_member/membership./music/group_membership/role
132
+ /award/award_ceremony/awards_presented./award/award_honor/award_winner
133
+ /film/film/prequel
134
+ /film/film/produced_by
135
+ /tv/tv_program/tv_producer./tv/tv_producer_term/producer_type
136
+ /sports/sports_position/players./sports/sports_team_roster/team
137
+ /olympics/olympic_games/participating_countries
138
+ /music/genre/parent_genre
139
+ /tv/tv_writer/tv_programs./tv/tv_program_writer_relationship/tv_program
140
+ /music/genre/artists
141
+ /film/film/genre
142
+ /people/person/employment_history./business/employment_tenure/company
143
+ /education/university/domestic_tuition./measurement_unit/dated_money_value/currency
144
+ /people/person/nationality
145
+ /location/country/capital
146
+ /location/statistical_region/gni_per_capita_in_ppp_dollars./measurement_unit/dated_money_value/currency
147
+ /base/aareas/schema/administrative_area/capital
148
+ /business/business_operation/industry
149
+ /location/hud_foreclosure_area/estimated_number_of_mortgages./measurement_unit/dated_integer/source
150
+ /film/film/other_crew./film/film_crew_gig/crewmember
151
+ /base/popstra/location/vacationers./base/popstra/vacation_choice/vacationer
152
+ /film/film/film_format
153
+ /medicine/disease/notable_people_with_this_condition
154
+ /film/film/costume_design_by
155
+ /government/government_office_category/officeholders./government/government_position_held/jurisdiction_of_office
156
+ /location/statistical_region/gdp_nominal./measurement_unit/dated_money_value/currency
157
+ /sports/sports_team/roster./baseball/baseball_roster_position/position
158
+ /award/award_winning_work/awards_won./award/award_honor/honored_for
159
+ /olympics/olympic_sport/athletes./olympics/olympic_athlete_affiliation/olympics
160
+ /celebrities/celebrity/sexual_relationships./celebrities/romantic_relationship/celebrity
161
+ /people/marriage_union_type/unions_of_this_type./people/marriage/location_of_ceremony
162
+ /organization/organization/child./organization/organization_relationship/child
163
+ /organization/organization_founder/organizations_founded
164
+ /sports/sports_team/sport
165
+ /people/ethnicity/geographic_distribution
166
+ /location/statistical_region/places_exported_to./location/imports_and_exports/exported_to
167
+ /location/country/official_language
168
+ /film/film/production_companies
169
+ /user/jg/default_domain/olympic_games/sports
170
+ /time/event/locations
171
+ /people/person/spouse_s./people/marriage/type_of_union
172
+ /government/governmental_body/members./government/government_position_held/legislative_sessions
173
+ /media_common/netflix_genre/titles
174
+ /user/alexander/philosophy/philosopher/interests
175
+ /film/film/runtime./film/film_cut/film_release_region
176
+ /education/educational_institution/students_graduates./education/education/student
177
+ /base/eating/practicer_of_diet/diet
178
+ /tv/non_character_role/tv_regular_personal_appearances./tv/tv_regular_personal_appearance/person
179
+ /sports/sports_position/players./sports/sports_team_roster/position
180
+ /sports/professional_sports_team/draft_picks./sports/sports_league_draft_pick/draft
181
+ /medicine/symptom/symptom_of
182
+ /film/person_or_entity_appearing_in_film/films./film/personal_film_appearance/type_of_appearance
183
+ /sports/sports_team_location/teams
184
+ /american_football/football_team/current_roster./sports/sports_team_roster/position
185
+ /people/person/places_lived./people/place_lived/location
186
+ /location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency
187
+ /film/film/personal_appearances./film/personal_film_appearance/person
188
+ /music/instrument/family
189
+ /sports/sports_team/roster./basketball/basketball_roster_position/position
190
+ /base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/service_location
191
+ /film/film/release_date_s./film/film_regional_release_date/film_release_region
192
+ /award/award_category/disciplines_or_subjects
193
+ /base/popstra/celebrity/friendship./base/popstra/friendship/participant
194
+ /music/performance_role/regular_performances./music/group_membership/group
195
+ /film/film/edited_by
196
+ /base/x2010fifaworldcupsouthafrica/world_cup_squad/current_world_cup_squad./base/x2010fifaworldcupsouthafrica/current_world_cup_squad/current_club
197
+ /base/popstra/celebrity/canoodled./base/popstra/canoodled/participant
198
+ /film/film/release_date_s./film/film_regional_release_date/film_release_distribution_medium
199
+ /film/film/other_crew./film/film_crew_gig/film_crew_role
200
+ /base/popstra/celebrity/breakup./base/popstra/breakup/participant
201
+ /film/film/country
202
+ /music/performance_role/regular_performances./music/group_membership/role
203
+ /sports/sports_team/roster./american_football/football_historical_roster_position/position_s
204
+ /film/film/release_date_s./film/film_regional_release_date/film_regional_debut_venue
205
+ /time/event/instance_of_recurring_event
206
+ /olympics/olympic_participating_country/athletes./olympics/olympic_athlete_affiliation/olympics
207
+ /organization/endowed_organization/endowment./measurement_unit/dated_money_value/currency
208
+ /travel/travel_destination/how_to_get_here./travel/transportation/mode_of_transportation
209
+ /baseball/baseball_team/team_stats./baseball/baseball_team_stats/season
210
+ /award/award_category/winners./award/award_honor/ceremony
211
+ /government/legislative_session/members./government/government_position_held/district_represented
212
+ /influence/influence_node/influenced_by
213
+ /base/culturalevent/event/entity_involved
214
+ /people/ethnicity/people
215
+ /sports/sport/pro_athletes./sports/pro_sports_played/athlete
216
+ /location/statistical_region/gdp_nominal_per_capita./measurement_unit/dated_money_value/currency
217
+ /location/hud_county_place/place
218
+ /base/aareas/schema/administrative_area/administrative_area_type
219
+ /base/locations/continents/countries_within
220
+ /sports/sports_position/players./american_football/football_historical_roster_position/position_s
221
+ /people/person/spouse_s./people/marriage/location_of_ceremony
222
+ /education/educational_institution/students_graduates./education/education/major_field_of_study
223
+ /film/film/written_by
224
+ /olympics/olympic_sport/athletes./olympics/olympic_athlete_affiliation/country
225
+ /music/performance_role/guest_performances./music/recording_contribution/performance_role
226
+ /film/film/featured_film_locations
227
+ /education/educational_institution_campus/educational_institution
228
+ /sports/pro_athlete/teams./sports/sports_team_roster/team
229
+ /people/ethnicity/languages_spoken
230
+ /film/film/executive_produced_by
231
+ /tv/tv_producer/programs_produced./tv/tv_producer_term/producer_type
232
+ /location/location/contains
233
+ /base/biblioness/bibs_location/country
234
+ /user/ktrueman/default_domain/international_organization/member_states
235
+ /music/performance_role/track_performances./music/track_contribution/role
236
+ /olympics/olympic_games/medals_awarded./olympics/olympic_medal_honor/medal
237
+ /base/saturdaynightlive/snl_cast_member/seasons./base/saturdaynightlive/snl_season_tenure/cast_members
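init_triple_input() in app.py consumes this list positionally: each relation string is mapped to a reserved [RELATION_i] special token by its line index. A small sketch of that mapping, using the path from app.py:

with open("./dataset/fb15k237/relations.txt") as f:
    relations = [line.strip().split("\t")[0] for line in f]
rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(relations)}
# "/people/person/profession" is line 10 of the file above, so it maps to "[RELATION_9]"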
requirement.txt ADDED
@@ -0,0 +1,5 @@
+allennlp
+kgeditor==1.0.0
+transformers
+jsonlines
+higher
src/__pycache__/modeling_bert.cpython-38.pyc ADDED
Binary file (39.5 kB).
 
src/__pycache__/one_shot_learner.cpython-38.pyc ADDED
Binary file (4.19 kB).
 
src/modeling_bert.py ADDED
@@ -0,0 +1,1338 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+
19
+ import math
20
+ import os
21
+ from dataclasses import dataclass
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from packaging import version
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ BaseModelOutputWithPoolingAndCrossAttentions,
34
+ CausalLMOutputWithCrossAttentions,
35
+ MaskedLMOutput,
36
+ MultipleChoiceModelOutput,
37
+ NextSentencePredictorOutput,
38
+ QuestionAnsweringModelOutput,
39
+ SequenceClassifierOutput,
40
+ TokenClassifierOutput,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
44
+ from transformers.utils import (
45
+ ModelOutput,
46
+ add_code_sample_docstrings,
47
+ add_start_docstrings,
48
+ add_start_docstrings_to_model_forward,
49
+ logging,
50
+ )
51
+ from transformers.models.bert.configuration_bert import BertConfig
52
+ from transformers.activations import ACT2FN
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
58
+ _CONFIG_FOR_DOC = "BertConfig"
59
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
60
+
61
+ # TokenClassification docstring
62
+ _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
63
+ _TOKEN_CLASS_EXPECTED_OUTPUT = (
64
+ "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
65
+ )
66
+ _TOKEN_CLASS_EXPECTED_LOSS = 0.01
67
+
68
+ # QuestionAnswering docstring
69
+ _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
70
+ _QA_EXPECTED_OUTPUT = "'a nice puppet'"
71
+ _QA_EXPECTED_LOSS = 7.41
72
+ _QA_TARGET_START_INDEX = 14
73
+ _QA_TARGET_END_INDEX = 15
74
+
75
+ # SequenceClassification docstring
76
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
77
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
78
+ _SEQ_CLASS_EXPECTED_LOSS = 0.01
79
+
80
+
81
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
82
+ "bert-base-uncased",
83
+ "bert-large-uncased",
84
+ "bert-base-cased",
85
+ "bert-large-cased",
86
+ "bert-base-multilingual-uncased",
87
+ "bert-base-multilingual-cased",
88
+ "bert-base-chinese",
89
+ "bert-base-german-cased",
90
+ "bert-large-uncased-whole-word-masking",
91
+ "bert-large-cased-whole-word-masking",
92
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
93
+ "bert-large-cased-whole-word-masking-finetuned-squad",
94
+ "bert-base-cased-finetuned-mrpc",
95
+ "bert-base-german-dbmdz-cased",
96
+ "bert-base-german-dbmdz-uncased",
97
+ "cl-tohoku/bert-base-japanese",
98
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
99
+ "cl-tohoku/bert-base-japanese-char",
100
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
101
+ "TurkuNLP/bert-base-finnish-cased-v1",
102
+ "TurkuNLP/bert-base-finnish-uncased-v1",
103
+ "wietsedv/bert-base-dutch-cased",
104
+ # See all BERT models at https://huggingface.co/models?filter=bert
105
+ ]
106
+
107
+
108
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
109
+ """Load tf checkpoints in a pytorch model."""
110
+ try:
111
+ import re
112
+
113
+ import numpy as np
114
+ import tensorflow as tf
115
+ except ImportError:
116
+ logger.error(
117
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
118
+ "https://www.tensorflow.org/install/ for installation instructions."
119
+ )
120
+ raise
121
+ tf_path = os.path.abspath(tf_checkpoint_path)
122
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
123
+ # Load weights from TF model
124
+ init_vars = tf.train.list_variables(tf_path)
125
+ names = []
126
+ arrays = []
127
+ for name, shape in init_vars:
128
+ logger.info(f"Loading TF weight {name} with shape {shape}")
129
+ array = tf.train.load_variable(tf_path, name)
130
+ names.append(name)
131
+ arrays.append(array)
132
+
133
+ for name, array in zip(names, arrays):
134
+ name = name.split("/")
135
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
136
+ # which are not required for using pretrained model
137
+ if any(
138
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
139
+ for n in name
140
+ ):
141
+ logger.info(f"Skipping {'/'.join(name)}")
142
+ continue
143
+ pointer = model
144
+ for m_name in name:
145
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
146
+ scope_names = re.split(r"_(\d+)", m_name)
147
+ else:
148
+ scope_names = [m_name]
149
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
150
+ pointer = getattr(pointer, "weight")
151
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
152
+ pointer = getattr(pointer, "bias")
153
+ elif scope_names[0] == "output_weights":
154
+ pointer = getattr(pointer, "weight")
155
+ elif scope_names[0] == "squad":
156
+ pointer = getattr(pointer, "classifier")
157
+ else:
158
+ try:
159
+ pointer = getattr(pointer, scope_names[0])
160
+ except AttributeError:
161
+ logger.info(f"Skipping {'/'.join(name)}")
162
+ continue
163
+ if len(scope_names) >= 2:
164
+ num = int(scope_names[1])
165
+ pointer = pointer[num]
166
+ if m_name[-11:] == "_embeddings":
167
+ pointer = getattr(pointer, "weight")
168
+ elif m_name == "kernel":
169
+ array = np.transpose(array)
170
+ try:
171
+ if pointer.shape != array.shape:
172
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
173
+ except AssertionError as e:
174
+ e.args += (pointer.shape, array.shape)
175
+ raise
176
+ logger.info(f"Initialize PyTorch weight {name}")
177
+ pointer.data = torch.from_numpy(array)
178
+ return model
179
+
180
+
181
+ class BertEmbeddings(nn.Module):
182
+ """Construct the embeddings from word, position and token_type embeddings."""
183
+
184
+ def __init__(self, config):
185
+ super().__init__()
186
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
187
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
188
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
189
+
190
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
191
+ # any TensorFlow checkpoint file
192
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
193
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
194
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
195
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
196
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
197
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
198
+ self.register_buffer(
199
+ "token_type_ids",
200
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
201
+ persistent=False,
202
+ )
203
+
204
+ def forward(
205
+ self,
206
+ input_ids: Optional[torch.LongTensor] = None,
207
+ token_type_ids: Optional[torch.LongTensor] = None,
208
+ position_ids: Optional[torch.LongTensor] = None,
209
+ inputs_embeds: Optional[torch.FloatTensor] = None,
210
+ past_key_values_length: int = 0,
211
+ ) -> torch.Tensor:
212
+ if input_ids is not None:
213
+ input_shape = input_ids.size()
214
+ else:
215
+ input_shape = inputs_embeds.size()[:-1]
216
+
217
+ seq_length = input_shape[1]
218
+
219
+ if position_ids is None:
220
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
221
+
222
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
223
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
224
+ # issue #5664
225
+ if token_type_ids is None:
226
+ if hasattr(self, "token_type_ids"):
227
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
228
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
229
+ token_type_ids = buffered_token_type_ids_expanded
230
+ else:
231
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
232
+
233
+ if inputs_embeds is None:
234
+ inputs_embeds = self.word_embeddings(input_ids)
235
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
236
+
237
+ embeddings = inputs_embeds + token_type_embeddings
238
+ if self.position_embedding_type == "absolute":
239
+ position_embeddings = self.position_embeddings(position_ids)
240
+ embeddings += position_embeddings
241
+ embeddings = self.LayerNorm(embeddings)
242
+ embeddings = self.dropout(embeddings)
243
+ return embeddings
244
+
245
+
246
+ class BertSelfAttention(nn.Module):
247
+ def __init__(self, config, position_embedding_type=None):
248
+ super().__init__()
249
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
250
+ raise ValueError(
251
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
252
+ f"heads ({config.num_attention_heads})"
253
+ )
254
+
255
+ self.num_attention_heads = config.num_attention_heads
256
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
257
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
258
+
259
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
260
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
261
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
262
+
263
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
264
+ self.position_embedding_type = position_embedding_type or getattr(
265
+ config, "position_embedding_type", "absolute"
266
+ )
267
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
268
+ self.max_position_embeddings = config.max_position_embeddings
269
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
270
+
271
+ self.is_decoder = config.is_decoder
272
+
273
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
274
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
275
+ x = x.view(new_x_shape)
276
+ return x.permute(0, 2, 1, 3)
277
+
278
+ def forward(
279
+ self,
280
+ hidden_states: torch.Tensor,
281
+ attention_mask: Optional[torch.FloatTensor] = None,
282
+ head_mask: Optional[torch.FloatTensor] = None,
283
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
284
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
285
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
286
+ output_attentions: Optional[bool] = False,
287
+ ) -> Tuple[torch.Tensor]:
288
+ mixed_query_layer = self.query(hidden_states)
289
+
290
+ # If this is instantiated as a cross-attention module, the keys
291
+ # and values come from an encoder; the attention mask needs to be
292
+ # such that the encoder's padding tokens are not attended to.
293
+ is_cross_attention = encoder_hidden_states is not None
294
+
295
+ if is_cross_attention and past_key_value is not None:
296
+ # reuse k,v, cross_attentions
297
+ key_layer = past_key_value[0]
298
+ value_layer = past_key_value[1]
299
+ attention_mask = encoder_attention_mask
300
+ elif is_cross_attention:
301
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
302
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
303
+ attention_mask = encoder_attention_mask
304
+ elif past_key_value is not None:
305
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
306
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
307
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
308
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
309
+ else:
310
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
311
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
312
+
313
+ query_layer = self.transpose_for_scores(mixed_query_layer)
314
+
315
+ if self.is_decoder:
316
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
317
+ # Further calls to cross_attention layer can then reuse all cross-attention
318
+ # key/value_states (first "if" case)
319
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
320
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
321
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
322
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
323
+ past_key_value = (key_layer, value_layer)
324
+
325
+ # Take the dot product between "query" and "key" to get the raw attention scores.
326
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
327
+
328
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
329
+ seq_length = hidden_states.size()[1]
330
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
331
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
332
+ distance = position_ids_l - position_ids_r
333
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
334
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
335
+
336
+ if self.position_embedding_type == "relative_key":
337
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
338
+ attention_scores = attention_scores + relative_position_scores
339
+ elif self.position_embedding_type == "relative_key_query":
340
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
341
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
342
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
343
+
344
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
345
+ if attention_mask is not None:
346
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
347
+ attention_scores = attention_scores + attention_mask
348
+
349
+ # Normalize the attention scores to probabilities.
350
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
351
+
352
+ # This is actually dropping out entire tokens to attend to, which might
353
+ # seem a bit unusual, but is taken from the original Transformer paper.
354
+ attention_probs = self.dropout(attention_probs)
355
+
356
+ # Mask heads if we want to
357
+ if head_mask is not None:
358
+ attention_probs = attention_probs * head_mask
359
+
360
+ context_layer = torch.matmul(attention_probs, value_layer)
361
+
362
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
363
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
364
+ context_layer = context_layer.view(new_context_layer_shape)
365
+
366
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
367
+
368
+ if self.is_decoder:
369
+ outputs = outputs + (past_key_value,)
370
+ return outputs
371
+
372
+ class BertSelfOutput(nn.Module):
373
+ def __init__(self, config):
374
+ super().__init__()
375
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
376
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
377
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
378
+
379
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
380
+ hidden_states = self.dense(hidden_states)
381
+ hidden_states = self.dropout(hidden_states)
382
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
383
+ return hidden_states
384
+
385
+ class BertAttention(nn.Module):
386
+ def __init__(self, config, position_embedding_type=None):
387
+ super().__init__()
388
+ self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
389
+ self.output = BertSelfOutput(config)
390
+ self.pruned_heads = set()
391
+
392
+ def prune_heads(self, heads):
393
+ if len(heads) == 0:
394
+ return
395
+ heads, index = find_pruneable_heads_and_indices(
396
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
397
+ )
398
+
399
+ # Prune linear layers
400
+ self.self.query = prune_linear_layer(self.self.query, index)
401
+ self.self.key = prune_linear_layer(self.self.key, index)
402
+ self.self.value = prune_linear_layer(self.self.value, index)
403
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
404
+
405
+ # Update hyper params and store pruned heads
406
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
407
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
408
+ self.pruned_heads = self.pruned_heads.union(heads)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.Tensor,
413
+ attention_mask: Optional[torch.FloatTensor] = None,
414
+ head_mask: Optional[torch.FloatTensor] = None,
415
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
416
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
417
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
418
+ output_attentions: Optional[bool] = False,
419
+ ) -> Tuple[torch.Tensor]:
420
+ self_outputs = self.self(
421
+ hidden_states,
422
+ attention_mask,
423
+ head_mask,
424
+ encoder_hidden_states,
425
+ encoder_attention_mask,
426
+ past_key_value,
427
+ output_attentions,
428
+ )
429
+ attention_output = self.output(self_outputs[0], hidden_states)
430
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
431
+ return outputs
432
+
433
+
434
+ class BertIntermediate(nn.Module):
435
+ def __init__(self, config):
436
+ super().__init__()
437
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
438
+ if isinstance(config.hidden_act, str):
439
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
440
+ else:
441
+ self.intermediate_act_fn = config.hidden_act
442
+
443
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
444
+ hidden_states = self.dense(hidden_states)
445
+ hidden_states = self.intermediate_act_fn(hidden_states)
446
+ return hidden_states
447
+
448
+
449
+ class BertOutputEx(nn.Module):
450
+ def __init__(self, config):
451
+ super().__init__()
452
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
453
+
454
+ self.dense_in_ex = nn.Linear(config.intermediate_size, config.ex_size)
455
+ self.dense_out_ex = nn.Linear(config.ex_size, config.hidden_size)
456
+
457
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
458
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
459
+
460
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
461
+ hidden_states_ori = self.dense(hidden_states)
462
+
463
+ hidden_states_ex = self.dense_in_ex(hidden_states)
464
+ hidden_states_ex = self.dense_out_ex(hidden_states_ex)
465
+
466
+ hidden_states_ori = self.dropout(hidden_states_ori)
467
+
468
+ hidden_states = self.LayerNorm(hidden_states_ori + hidden_states_ex + input_tensor)
469
+ return hidden_states
470
+
471
+ class BertOutput(nn.Module):
472
+ def __init__(self, config):
473
+ super().__init__()
474
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
475
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
477
+
478
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
479
+ hidden_states = self.dense(hidden_states)
480
+ hidden_states = self.dropout(hidden_states)
481
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
482
+ return hidden_states
483
+
484
+ class BertLayerEx(nn.Module):
485
+ def __init__(self, config):
486
+ super().__init__()
487
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
488
+ self.seq_len_dim = 1
489
+ self.attention = BertAttention(config)
490
+ self.is_decoder = config.is_decoder
491
+ self.add_cross_attention = config.add_cross_attention
492
+ if self.add_cross_attention:
493
+ if not self.is_decoder:
494
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
495
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
496
+ self.intermediate = BertIntermediate(config)
497
+ self.output = BertOutputEx(config)
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ attention_mask: Optional[torch.FloatTensor] = None,
503
+ head_mask: Optional[torch.FloatTensor] = None,
504
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
505
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
506
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
507
+ output_attentions: Optional[bool] = False,
508
+ ) -> Tuple[torch.Tensor]:
509
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
510
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
511
+ self_attention_outputs = self.attention(
512
+ hidden_states,
513
+ attention_mask,
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ past_key_value=self_attn_past_key_value,
517
+ )
518
+ attention_output = self_attention_outputs[0]
519
+
520
+ # if decoder, the last output is tuple of self-attn cache
521
+ if self.is_decoder:
522
+ outputs = self_attention_outputs[1:-1]
523
+ present_key_value = self_attention_outputs[-1]
524
+ else:
525
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
526
+
527
+ cross_attn_present_key_value = None
528
+ if self.is_decoder and encoder_hidden_states is not None:
529
+ if not hasattr(self, "crossattention"):
530
+ raise ValueError(
531
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
532
+ " by setting `config.add_cross_attention=True`"
533
+ )
534
+
535
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
536
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
537
+ cross_attention_outputs = self.crossattention(
538
+ attention_output,
539
+ attention_mask,
540
+ head_mask,
541
+ encoder_hidden_states,
542
+ encoder_attention_mask,
543
+ cross_attn_past_key_value,
544
+ output_attentions,
545
+ )
546
+ attention_output = cross_attention_outputs[0]
547
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
548
+
549
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
550
+ cross_attn_present_key_value = cross_attention_outputs[-1]
551
+ present_key_value = present_key_value + cross_attn_present_key_value
552
+
553
+ layer_output = apply_chunking_to_forward(
554
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
555
+ )
556
+ outputs = (layer_output,) + outputs
557
+
558
+ # if decoder, return the attn key/values as the last output
559
+ if self.is_decoder:
560
+ outputs = outputs + (present_key_value,)
561
+
562
+ return outputs
563
+
564
+ def feed_forward_chunk(self, attention_output):
565
+ intermediate_output = self.intermediate(attention_output)
566
+ layer_output = self.output(intermediate_output, attention_output)
567
+ return layer_output
568
+
569
+ class BertLayer(nn.Module):
570
+ def __init__(self, config):
571
+ super().__init__()
572
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
573
+ self.seq_len_dim = 1
574
+ self.attention = BertAttention(config)
575
+ self.is_decoder = config.is_decoder
576
+ self.add_cross_attention = config.add_cross_attention
577
+ if self.add_cross_attention:
578
+ if not self.is_decoder:
579
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
580
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
581
+ self.intermediate = BertIntermediate(config)
582
+ self.output = BertOutput(config)
583
+
584
+ def forward(
585
+ self,
586
+ hidden_states: torch.Tensor,
587
+ attention_mask: Optional[torch.FloatTensor] = None,
588
+ head_mask: Optional[torch.FloatTensor] = None,
589
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
590
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
591
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
592
+ output_attentions: Optional[bool] = False,
593
+ ) -> Tuple[torch.Tensor]:
594
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
595
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
596
+ self_attention_outputs = self.attention(
597
+ hidden_states,
598
+ attention_mask,
599
+ head_mask,
600
+ output_attentions=output_attentions,
601
+ past_key_value=self_attn_past_key_value,
602
+ )
603
+ attention_output = self_attention_outputs[0]
604
+
605
+ # if decoder, the last output is tuple of self-attn cache
606
+ if self.is_decoder:
607
+ outputs = self_attention_outputs[1:-1]
608
+ present_key_value = self_attention_outputs[-1]
609
+ else:
610
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
611
+
612
+ cross_attn_present_key_value = None
613
+ if self.is_decoder and encoder_hidden_states is not None:
614
+ if not hasattr(self, "crossattention"):
615
+ raise ValueError(
616
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
617
+ " by setting `config.add_cross_attention=True`"
618
+ )
619
+
620
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
621
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
622
+ cross_attention_outputs = self.crossattention(
623
+ attention_output,
624
+ attention_mask,
625
+ head_mask,
626
+ encoder_hidden_states,
627
+ encoder_attention_mask,
628
+ cross_attn_past_key_value,
629
+ output_attentions,
630
+ )
631
+ attention_output = cross_attention_outputs[0]
632
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
633
+
634
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
635
+ cross_attn_present_key_value = cross_attention_outputs[-1]
636
+ present_key_value = present_key_value + cross_attn_present_key_value
637
+
638
+ layer_output = apply_chunking_to_forward(
639
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
640
+ )
641
+ outputs = (layer_output,) + outputs
642
+
643
+ # if decoder, return the attn key/values as the last output
644
+ if self.is_decoder:
645
+ outputs = outputs + (present_key_value,)
646
+
647
+ return outputs
648
+
649
+ def feed_forward_chunk(self, attention_output):
650
+ intermediate_output = self.intermediate(attention_output)
651
+ layer_output = self.output(intermediate_output, attention_output)
652
+ return layer_output
653
+
654
+ class BertEncoder(nn.Module):
655
+ def __init__(self, config):
656
+ super().__init__()
657
+ self.config = config
658
+ kb_layer = self.config.kb_layer
659
+ self.layer = nn.ModuleList([BertLayerEx(config) if i in kb_layer else BertLayer(config) for i in range(config.num_hidden_layers)])
660
+ self.gradient_checkpointing = False
661
+
662
+ def forward(
663
+ self,
664
+ hidden_states: torch.Tensor,
665
+ attention_mask: Optional[torch.FloatTensor] = None,
666
+ head_mask: Optional[torch.FloatTensor] = None,
667
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
668
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
669
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
670
+ use_cache: Optional[bool] = None,
671
+ output_attentions: Optional[bool] = False,
672
+ output_hidden_states: Optional[bool] = False,
673
+ return_dict: Optional[bool] = True,
674
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
675
+ all_hidden_states = () if output_hidden_states else None
676
+ all_self_attentions = () if output_attentions else None
677
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
678
+
679
+ next_decoder_cache = () if use_cache else None
680
+ for i, layer_module in enumerate(self.layer):
681
+ if output_hidden_states:
682
+ all_hidden_states = all_hidden_states + (hidden_states,)
683
+
684
+ layer_head_mask = head_mask[i] if head_mask is not None else None
685
+ past_key_value = past_key_values[i] if past_key_values is not None else None
686
+
687
+ if self.gradient_checkpointing and self.training:
688
+
689
+ if use_cache:
690
+ logger.warning(
691
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
692
+ )
693
+ use_cache = False
694
+
695
+ def create_custom_forward(module):
696
+ def custom_forward(*inputs):
697
+ return module(*inputs, past_key_value, output_attentions)
698
+
699
+ return custom_forward
700
+
701
+ layer_outputs = torch.utils.checkpoint.checkpoint(
702
+ create_custom_forward(layer_module),
703
+ hidden_states,
704
+ attention_mask,
705
+ layer_head_mask,
706
+ encoder_hidden_states,
707
+ encoder_attention_mask,
708
+ )
709
+ else:
710
+ layer_outputs = layer_module(
711
+ hidden_states,
712
+ attention_mask,
713
+ layer_head_mask,
714
+ encoder_hidden_states,
715
+ encoder_attention_mask,
716
+ past_key_value,
717
+ output_attentions,
718
+ )
719
+
720
+ hidden_states = layer_outputs[0]
721
+ if use_cache:
722
+ next_decoder_cache += (layer_outputs[-1],)
723
+ if output_attentions:
724
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
725
+ if self.config.add_cross_attention:
726
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
727
+
728
+ if output_hidden_states:
729
+ all_hidden_states = all_hidden_states + (hidden_states,)
730
+
731
+ if not return_dict:
732
+ return tuple(
733
+ v
734
+ for v in [
735
+ hidden_states,
736
+ next_decoder_cache,
737
+ all_hidden_states,
738
+ all_self_attentions,
739
+ all_cross_attentions,
740
+ ]
741
+ if v is not None
742
+ )
743
+ return BaseModelOutputWithPastAndCrossAttentions(
744
+ last_hidden_state=hidden_states,
745
+ past_key_values=next_decoder_cache,
746
+ hidden_states=all_hidden_states,
747
+ attentions=all_self_attentions,
748
+ cross_attentions=all_cross_attentions,
749
+ )
750
+
751
+
752
+ class BertPooler(nn.Module):
753
+ def __init__(self, config):
754
+ super().__init__()
755
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
756
+ self.activation = nn.Tanh()
757
+
758
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
759
+ # We "pool" the model by simply taking the hidden state corresponding
760
+ # to the first token.
761
+ first_token_tensor = hidden_states[:, 0]
762
+ pooled_output = self.dense(first_token_tensor)
763
+ pooled_output = self.activation(pooled_output)
764
+ return pooled_output
765
+
766
+ class BertPredictionHeadTransform(nn.Module):
767
+ def __init__(self, config):
768
+ super().__init__()
769
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
770
+ if isinstance(config.hidden_act, str):
771
+ self.transform_act_fn = ACT2FN[config.hidden_act]
772
+ else:
773
+ self.transform_act_fn = config.hidden_act
774
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
775
+
776
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
777
+ hidden_states = self.dense(hidden_states)
778
+ hidden_states = self.transform_act_fn(hidden_states)
779
+ hidden_states = self.LayerNorm(hidden_states)
780
+ return hidden_states
781
+
782
+ class BertLMPredictionHead(nn.Module):
783
+ def __init__(self, config):
784
+ super().__init__()
785
+ self.transform = BertPredictionHeadTransform(config)
786
+
787
+ # The output weights are the same as the input embeddings, but there is
788
+ # an output-only bias for each token.
789
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
790
+
791
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
792
+
793
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
794
+ self.decoder.bias = self.bias
795
+
796
+ def forward(self, hidden_states):
797
+ hidden_states = self.transform(hidden_states)
798
+ hidden_states = self.decoder(hidden_states)
799
+ return hidden_states
800
+
801
+ class BertOnlyMLMHead(nn.Module):
802
+ def __init__(self, config):
803
+ super().__init__()
804
+ self.predictions = BertLMPredictionHead(config)
805
+
806
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
807
+ prediction_scores = self.predictions(sequence_output)
808
+ return prediction_scores
809
+
810
+ class BertPreTrainedModel(PreTrainedModel):
811
+ """
812
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
813
+ models.
814
+ """
815
+
816
+ config_class = BertConfig
817
+ load_tf_weights = load_tf_weights_in_bert
818
+ base_model_prefix = "bert"
819
+ supports_gradient_checkpointing = True
820
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
821
+
822
+ def _init_weights(self, module):
823
+ """Initialize the weights"""
824
+ if isinstance(module, nn.Linear):
825
+ # Slightly different from the TF version which uses truncated_normal for initialization
826
+ # cf https://github.com/pytorch/pytorch/pull/5617
827
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
828
+ if module.bias is not None:
829
+ module.bias.data.zero_()
830
+ elif isinstance(module, nn.Embedding):
831
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
832
+ if module.padding_idx is not None:
833
+ module.weight.data[module.padding_idx].zero_()
834
+ elif isinstance(module, nn.LayerNorm):
835
+ module.bias.data.zero_()
836
+ module.weight.data.fill_(1.0)
837
+
838
+ def _set_gradient_checkpointing(self, module, value=False):
839
+ if isinstance(module, BertEncoder):
840
+ module.gradient_checkpointing = value
841
+
842
+
843
+ @dataclass
844
+ class BertForPreTrainingOutput(ModelOutput):
845
+ """
846
+ Output type of [`BertForPreTraining`].
847
+
848
+ Args:
849
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
850
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
851
+ (classification) loss.
852
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
853
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
854
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
855
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
856
+ before SoftMax).
857
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
858
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
859
+ shape `(batch_size, sequence_length, hidden_size)`.
860
+
861
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
862
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
863
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
864
+ sequence_length)`.
865
+
866
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
867
+ heads.
868
+ """
869
+
870
+ loss: Optional[torch.FloatTensor] = None
871
+ prediction_logits: torch.FloatTensor = None
872
+ seq_relationship_logits: torch.FloatTensor = None
873
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
874
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
875
+
876
+
877
+ BERT_START_DOCSTRING = r"""
878
+
879
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
880
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
881
+ etc.)
882
+
883
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
884
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
885
+ and behavior.
886
+
887
+ Parameters:
888
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
889
+ Initializing with a config file does not load the weights associated with the model, only the
890
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
891
+ """
892
+
893
+ BERT_INPUTS_DOCSTRING = r"""
894
+ Args:
895
+ input_ids (`torch.LongTensor` of shape `({0})`):
896
+ Indices of input sequence tokens in the vocabulary.
897
+
898
+ Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
899
+ [`PreTrainedTokenizer.__call__`] for details.
900
+
901
+ [What are input IDs?](../glossary#input-ids)
902
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
903
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
904
+
905
+ - 1 for tokens that are **not masked**,
906
+ - 0 for tokens that are **masked**.
907
+
908
+ [What are attention masks?](../glossary#attention-mask)
909
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
910
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
911
+ 1]`:
912
+
913
+ - 0 corresponds to a *sentence A* token,
914
+ - 1 corresponds to a *sentence B* token.
915
+
916
+ [What are token type IDs?](../glossary#token-type-ids)
917
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
918
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
919
+ config.max_position_embeddings - 1]`.
920
+
921
+ [What are position IDs?](../glossary#position-ids)
922
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
923
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
924
+
925
+ - 1 indicates the head is **not masked**,
926
+ - 0 indicates the head is **masked**.
927
+
928
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
929
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
930
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
931
+ model's internal embedding lookup matrix.
932
+ output_attentions (`bool`, *optional*):
933
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
934
+ tensors for more detail.
935
+ output_hidden_states (`bool`, *optional*):
936
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
937
+ more detail.
938
+ return_dict (`bool`, *optional*):
939
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
940
+ """
941
+
942
+ @add_start_docstrings(
943
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
944
+ BERT_START_DOCSTRING,
945
+ )
946
+ class BertModel(BertPreTrainedModel):
947
+ """
948
+
949
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
950
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
951
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
952
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
953
+
954
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
955
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
956
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
957
+ """
958
+
959
+ def __init__(self, config, add_pooling_layer=True):
960
+ super().__init__(config)
961
+ self.config = config
962
+
963
+ self.embeddings = BertEmbeddings(config)
964
+ self.encoder = BertEncoder(config)
965
+
966
+ self.pooler = BertPooler(config) if add_pooling_layer else None
967
+
968
+ # Initialize weights and apply final processing
969
+ self.post_init()
970
+
971
+ def get_input_embeddings(self):
972
+ return self.embeddings.word_embeddings
973
+
974
+ def set_input_embeddings(self, value):
975
+ self.embeddings.word_embeddings = value
976
+
977
+ def _prune_heads(self, heads_to_prune):
978
+ """
979
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
980
+ class PreTrainedModel
981
+ """
982
+ for layer, heads in heads_to_prune.items():
983
+ self.encoder.layer[layer].attention.prune_heads(heads)
984
+
985
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
986
+ @add_code_sample_docstrings(
987
+ processor_class=_TOKENIZER_FOR_DOC,
988
+ checkpoint=_CHECKPOINT_FOR_DOC,
989
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
990
+ config_class=_CONFIG_FOR_DOC,
991
+ )
992
+ def forward(
993
+ self,
994
+ input_ids: Optional[torch.Tensor] = None,
995
+ attention_mask: Optional[torch.Tensor] = None,
996
+ token_type_ids: Optional[torch.Tensor] = None,
997
+ position_ids: Optional[torch.Tensor] = None,
998
+ head_mask: Optional[torch.Tensor] = None,
999
+ inputs_embeds: Optional[torch.Tensor] = None,
1000
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1001
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1002
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1003
+ use_cache: Optional[bool] = None,
1004
+ output_attentions: Optional[bool] = None,
1005
+ output_hidden_states: Optional[bool] = None,
1006
+ return_dict: Optional[bool] = None,
1007
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1008
+ r"""
1009
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1010
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1011
+ the model is configured as a decoder.
1012
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1013
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1014
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1015
+
1016
+ - 1 for tokens that are **not masked**,
1017
+ - 0 for tokens that are **masked**.
1018
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1019
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1020
+
1021
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1022
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1023
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1024
+ use_cache (`bool`, *optional*):
1025
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1026
+ `past_key_values`).
1027
+ """
1028
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1029
+ output_hidden_states = (
1030
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1031
+ )
1032
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1033
+
1034
+ if self.config.is_decoder:
1035
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1036
+ else:
1037
+ use_cache = False
1038
+
1039
+ if input_ids is not None and inputs_embeds is not None:
1040
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1041
+ elif input_ids is not None:
1042
+ input_shape = input_ids.size()
1043
+ elif inputs_embeds is not None:
1044
+ input_shape = inputs_embeds.size()[:-1]
1045
+ else:
1046
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1047
+
1048
+ batch_size, seq_length = input_shape
1049
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1050
+
1051
+ # past_key_values_length
1052
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1053
+
1054
+ if attention_mask is None:
1055
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1056
+
1057
+ if token_type_ids is None:
1058
+ if hasattr(self.embeddings, "token_type_ids"):
1059
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1060
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1061
+ token_type_ids = buffered_token_type_ids_expanded
1062
+ else:
1063
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1064
+
1065
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1066
+ # ourselves in which case we just need to make it broadcastable to all heads.
1067
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1068
+
1069
+ # If a 2D or 3D attention mask is provided for the cross-attention
1070
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1071
+ if self.config.is_decoder and encoder_hidden_states is not None:
1072
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1073
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1074
+ if encoder_attention_mask is None:
1075
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1076
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1077
+ else:
1078
+ encoder_extended_attention_mask = None
1079
+
1080
+ # Prepare head mask if needed
1081
+ # 1.0 in head_mask indicate we keep the head
1082
+ # attention_probs has shape bsz x n_heads x N x N
1083
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1084
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1085
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1086
+
1087
+ embedding_output = self.embeddings(
1088
+ input_ids=input_ids,
1089
+ position_ids=position_ids,
1090
+ token_type_ids=token_type_ids,
1091
+ inputs_embeds=inputs_embeds,
1092
+ past_key_values_length=past_key_values_length,
1093
+ )
1094
+ encoder_outputs = self.encoder(
1095
+ embedding_output,
1096
+ attention_mask=extended_attention_mask,
1097
+ head_mask=head_mask,
1098
+ encoder_hidden_states=encoder_hidden_states,
1099
+ encoder_attention_mask=encoder_extended_attention_mask,
1100
+ past_key_values=past_key_values,
1101
+ use_cache=use_cache,
1102
+ output_attentions=output_attentions,
1103
+ output_hidden_states=output_hidden_states,
1104
+ return_dict=return_dict,
1105
+ )
1106
+ sequence_output = encoder_outputs[0]
1107
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1108
+
1109
+ if not return_dict:
1110
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1111
+
1112
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1113
+ last_hidden_state=sequence_output,
1114
+ pooler_output=pooled_output,
1115
+ past_key_values=encoder_outputs.past_key_values,
1116
+ hidden_states=encoder_outputs.hidden_states,
1117
+ attentions=encoder_outputs.attentions,
1118
+ cross_attentions=encoder_outputs.cross_attentions,
1119
+ )
1120
+
1121
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
1122
+ class EXBertForMaskedLM(BertPreTrainedModel):
1123
+
1124
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1125
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1126
+
1127
+ def __init__(self, config):
1128
+ super().__init__(config)
1129
+
1130
+ if config.is_decoder:
1131
+ logger.warning(
1132
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1133
+ "bi-directional self-attention."
1134
+ )
1135
+
1136
+ self.bert = BertModel(config, add_pooling_layer=False)
1137
+ self.cls = BertOnlyMLMHead(config)
1138
+
1139
+ # Initialize weights and apply final processing
1140
+ self.post_init()
1141
+
1142
+ def get_output_embeddings(self):
1143
+ return self.cls.predictions.decoder
1144
+
1145
+ def set_output_embeddings(self, new_embeddings):
1146
+ self.cls.predictions.decoder = new_embeddings
1147
+
1148
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1149
+ @add_code_sample_docstrings(
1150
+ processor_class=_TOKENIZER_FOR_DOC,
1151
+ checkpoint=_CHECKPOINT_FOR_DOC,
1152
+ output_type=MaskedLMOutput,
1153
+ config_class=_CONFIG_FOR_DOC,
1154
+ expected_output="'paris'",
1155
+ expected_loss=0.88,
1156
+ )
1157
+ def forward(
1158
+ self,
1159
+ input_ids: Optional[torch.Tensor] = None,
1160
+ attention_mask: Optional[torch.Tensor] = None,
1161
+ token_type_ids: Optional[torch.Tensor] = None,
1162
+ position_ids: Optional[torch.Tensor] = None,
1163
+ head_mask: Optional[torch.Tensor] = None,
1164
+ inputs_embeds: Optional[torch.Tensor] = None,
1165
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1166
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1167
+ labels: Optional[torch.Tensor] = None,
1168
+ output_attentions: Optional[bool] = None,
1169
+ output_hidden_states: Optional[bool] = None,
1170
+ return_dict: Optional[bool] = None,
1171
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1172
+ r"""
1173
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1174
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1175
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1176
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1177
+ """
1178
+
1179
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1180
+
1181
+ outputs = self.bert(
1182
+ input_ids,
1183
+ attention_mask=attention_mask,
1184
+ token_type_ids=token_type_ids,
1185
+ position_ids=position_ids,
1186
+ head_mask=head_mask,
1187
+ inputs_embeds=inputs_embeds,
1188
+ encoder_hidden_states=encoder_hidden_states,
1189
+ encoder_attention_mask=encoder_attention_mask,
1190
+ output_attentions=output_attentions,
1191
+ output_hidden_states=output_hidden_states,
1192
+ return_dict=return_dict,
1193
+ )
1194
+
1195
+ sequence_output = outputs[0]
1196
+ prediction_scores = self.cls(sequence_output)
1197
+ # pos = (input_ids == self.config.mask_token_id).nonzero(as_tuple=True)
1198
+
1199
+ masked_lm_loss = None
1200
+ if labels is not None:
1201
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1202
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1203
+ # masked_lm_loss = loss_fct(prediction_scores[pos[0], pos[1], :].view(-1, self.config.vocab_size), labels.view(-1))
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1208
+
1209
+ return MaskedLMOutput(
1210
+ loss=masked_lm_loss,
1211
+ logits=prediction_scores,
1212
+ hidden_states=outputs.hidden_states,
1213
+ attentions=outputs.attentions,
1214
+ )
1215
+
1216
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1217
+ input_shape = input_ids.shape
1218
+ effective_batch_size = input_shape[0]
1219
+
1220
+ # add a dummy token
1221
+ if self.config.pad_token_id is None:
1222
+ raise ValueError("The PAD token should be defined for generation")
1223
+
1224
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1225
+ dummy_token = torch.full(
1226
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1227
+ )
1228
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1229
+
1230
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1231
+
1232
+ @add_start_docstrings(
1233
+ """
1234
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1235
+ output) e.g. for GLUE tasks.
1236
+ """,
1237
+ BERT_START_DOCSTRING,
1238
+ )
1239
+ class EXBertForSequenceClassification(BertPreTrainedModel):
1240
+ def __init__(self, config):
1241
+ super().__init__(config)
1242
+ self.num_labels = config.num_labels
1243
+ self.config = config
1244
+
1245
+ self.bert = BertModel(config)
1246
+ classifier_dropout = (
1247
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1248
+ )
1249
+ self.dropout = nn.Dropout(classifier_dropout)
1250
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1256
+ @add_code_sample_docstrings(
1257
+ processor_class=_TOKENIZER_FOR_DOC,
1258
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1259
+ output_type=SequenceClassifierOutput,
1260
+ config_class=_CONFIG_FOR_DOC,
1261
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1262
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1263
+ )
1264
+ def forward(
1265
+ self,
1266
+ input_ids: Optional[torch.Tensor] = None,
1267
+ attention_mask: Optional[torch.Tensor] = None,
1268
+ token_type_ids: Optional[torch.Tensor] = None,
1269
+ position_ids: Optional[torch.Tensor] = None,
1270
+ head_mask: Optional[torch.Tensor] = None,
1271
+ inputs_embeds: Optional[torch.Tensor] = None,
1272
+ labels: Optional[torch.Tensor] = None,
1273
+ output_attentions: Optional[bool] = None,
1274
+ output_hidden_states: Optional[bool] = None,
1275
+ return_dict: Optional[bool] = None,
1276
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1277
+ r"""
1278
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1279
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1280
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1281
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1282
+ """
1283
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1284
+
1285
+ outputs = self.bert(
1286
+ input_ids,
1287
+ attention_mask=attention_mask,
1288
+ token_type_ids=token_type_ids,
1289
+ position_ids=position_ids,
1290
+ head_mask=head_mask,
1291
+ inputs_embeds=inputs_embeds,
1292
+ output_attentions=output_attentions,
1293
+ output_hidden_states=output_hidden_states,
1294
+ return_dict=return_dict,
1295
+ )
1296
+
1297
+ pooled_output = outputs[1]
1298
+
1299
+ pooled_output = self.dropout(pooled_output)
1300
+ logits = self.classifier(pooled_output)
1301
+
1302
+ loss = None
1303
+ if labels is not None:
1304
+ if self.config.problem_type is None:
1305
+ if self.num_labels == 1:
1306
+ self.config.problem_type = "regression"
1307
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1308
+ self.config.problem_type = "single_label_classification"
1309
+ else:
1310
+ self.config.problem_type = "multi_label_classification"
1311
+
1312
+ if self.config.problem_type == "regression":
1313
+ loss_fct = MSELoss()
1314
+ if self.num_labels == 1:
1315
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1316
+ else:
1317
+ loss = loss_fct(logits, labels)
1318
+ elif self.config.problem_type == "single_label_classification":
1319
+ loss_fct = CrossEntropyLoss()
1320
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1321
+ elif self.config.problem_type == "multi_label_classification":
1322
+ loss_fct = BCEWithLogitsLoss()
1323
+ loss = loss_fct(logits, labels)
1324
+ if not return_dict:
1325
+ output = (logits,) + outputs[2:]
1326
+ return ((loss,) + output) if loss is not None else output
1327
+
1328
+ return SequenceClassifierOutput(
1329
+ loss=loss,
1330
+ logits=logits,
1331
+ hidden_states=outputs.hidden_states,
1332
+ attentions=outputs.attentions,
1333
+ )
1334
+
1335
+ bert_mapping = {
1336
+ 'FT': EXBertForSequenceClassification,
1337
+ 'PT': EXBertForMaskedLM
1338
+ }
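The main departure from the stock `transformers` BERT code above is `BertOutputEx`: besides the usual feed-forward output projection, it routes the intermediate activations through an extra bottleneck (`dense_in_ex` → `dense_out_ex`) whose output joins the residual connection, and `BertEncoder` substitutes this variant only for the layers listed in `config.kb_layer`. A minimal, self-contained sketch of that bottleneck residual is shown below; the hidden, intermediate, and `ex_size` dimensions are illustrative placeholders, not the values used by the released checkpoint.

```python
import torch
import torch.nn as nn

# Sketch of the extra bottleneck path added by BertOutputEx (sizes are illustrative).
class BottleneckOutput(nn.Module):
    def __init__(self, intermediate_size=3072, hidden_size=768, ex_size=64, dropout=0.1):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)     # standard FFN output projection
        self.dense_in_ex = nn.Linear(intermediate_size, ex_size)   # down-project into the extra bottleneck
        self.dense_out_ex = nn.Linear(ex_size, hidden_size)        # up-project back to the hidden size
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, input_tensor):
        ori = self.dropout(self.dense(hidden_states))
        ex = self.dense_out_ex(self.dense_in_ex(hidden_states))    # extra path carrying the knowledge edit
        return self.LayerNorm(ori + ex + input_tensor)             # both paths merge into the residual stream

ffn_act = torch.randn(2, 5, 3072)   # (batch, seq_len, intermediate_size)
attn_out = torch.randn(2, 5, 768)   # residual input from the attention block
print(BottleneckOutput()(ffn_act, attn_out).shape)  # torch.Size([2, 5, 768])
```

Presumably the point of confining the change to these `ex` projections in a few `kb_layer` layers is that an edit only has to touch a small, well-localized set of parameters.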
src/models/__pycache__/one_shot_learner.cpython-38.pyc ADDED
Binary file (4.19 kB).
 
src/models/one_shot_learner.py ADDED
@@ -0,0 +1,157 @@
1
+ import torch
2
+ from allennlp.modules.feedforward import FeedForward
3
+ from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
4
+ from higher.patch import monkeypatch as make_functional
5
+
6
+
7
+ class ConditionedParameter(torch.nn.Module):
8
+ def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
9
+ super().__init__()
10
+ self.parameter_shape = parameter.shape
11
+
12
+ if len(self.parameter_shape) == 2: # condition_dim is the tensor obtained from the LSTM; a linear layer learns to map it back to 768 dims to serve as the updated param dict
13
+ self.conditioners = torch.nn.Sequential(
14
+ torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
15
+ torch.nn.Tanh(),
16
+ torch.nn.utils.weight_norm(
17
+ torch.nn.Linear(
18
+ hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1
19
+ )
20
+ ),
21
+ )
22
+ elif len(self.parameter_shape) == 1:
23
+ self.conditioners = torch.nn.Sequential(
24
+ torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
25
+ torch.nn.Tanh(),
26
+ torch.nn.utils.weight_norm(
27
+ torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1)
28
+ ),
29
+ )
30
+ else:
31
+ raise RuntimeError()
32
+
33
+ self.max_scale = max_scale
34
+
35
+ def forward(self, inputs, grad):
36
+
37
+ if len(self.parameter_shape) == 2:
38
+ (
39
+ conditioner_cola,
40
+ conditioner_rowa,
41
+ conditioner_colb,
42
+ conditioner_rowb,
43
+ conditioner_norm,
44
+ ) = self.conditioners(inputs).split(
45
+ [
46
+ self.parameter_shape[1],
47
+ self.parameter_shape[0],
48
+ self.parameter_shape[1],
49
+ self.parameter_shape[0],
50
+ 1,
51
+ ],
52
+ dim=-1,
53
+ )
54
+
55
+ a = conditioner_rowa.softmax(-1).T @ conditioner_cola
56
+ b = conditioner_rowb.softmax(-1).T @ conditioner_colb
57
+
58
+ elif len(self.parameter_shape) == 1:
59
+ a, b, conditioner_norm = self.conditioners(inputs).split(
60
+ [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
61
+ )
62
+ else:
63
+ raise RuntimeError()
64
+
65
+ return (
66
+ self.max_scale
67
+ * torch.mean(conditioner_norm.sigmoid(), dim=0).squeeze() # with multiple conditions we simply take the mean
68
+ * (grad * a.squeeze() + b.squeeze())
69
+ )
70
+
71
+
72
+ class LSTMConditioner(torch.nn.Module):
73
+ def __init__(
74
+ self,
75
+ vocab_dim=30522,
76
+ embedding_dim=768,
77
+ hidden_dim=256,
78
+ output_dim=1024,
79
+ embedding_init=None,
80
+ ):
81
+ super().__init__()
82
+ self.embedding = torch.nn.Embedding(
83
+ num_embeddings=vocab_dim,
84
+ embedding_dim=embedding_dim,
85
+ padding_idx=0,
86
+ _weight=embedding_init,
87
+ )
88
+ self.lstm = PytorchSeq2VecWrapper(
89
+ torch.nn.LSTM(
90
+ input_size=embedding_dim,
91
+ hidden_size=hidden_dim,
92
+ num_layers=1,
93
+ bidirectional=True,
94
+ batch_first=True,
95
+ )
96
+ )
97
+ self.linear = FeedForward(
98
+ input_dim=hidden_dim * 2,
99
+ num_layers=1,
100
+ hidden_dims=[output_dim],
101
+ activations=[torch.nn.Tanh()],
102
+ )
103
+
104
+ def forward(self, inputs, masks):
105
+ return self.linear(self.lstm(self.embedding(inputs), masks)) # 1, 64
106
+
107
+
108
+ class OneShotLearner(torch.nn.Module):
109
+ def __init__(
110
+ self,
111
+ model,
112
+ vocab_dim=30522,
113
+ embedding_dim=768,
114
+ hidden_dim=128,
115
+ condition_dim=1024,
116
+ include_set={},
117
+ max_scale=1e-3,
118
+ embedding_init=None,
119
+ ):
120
+ super().__init__()
121
+
122
+ self.param2conditioner_map = {
123
+ n: "{}_conditioner".format(n).replace(".", "_")
124
+ for n, p in model.named_parameters()
125
+ if n in include_set
126
+ }
127
+
128
+ self.conditioners = torch.nn.ModuleDict(
129
+ {
130
+ self.param2conditioner_map[n]: ConditionedParameter(
131
+ p,
132
+ condition_dim,
133
+ hidden_dim,
134
+ max_scale=max_scale,
135
+ )
136
+ for n, p in model.named_parameters()
137
+ if n in include_set
138
+ }
139
+ )
140
+
141
+ self.condition = LSTMConditioner(
142
+ vocab_dim,
143
+ embedding_dim,
144
+ hidden_dim,
145
+ condition_dim,
146
+ embedding_init=embedding_init,
147
+ )
148
+
149
+ def forward(self, inputs, masks, grads=None):
150
+ condition = self.condition(inputs, masks) # the LSTM conditioner outputs the condition vector
151
+ return {
152
+ p: self.conditioners[self.param2conditioner_map[p]](
153
+ condition,
154
+ grad=grads[p] if grads else None,
155
+ )
156
+ for p, c in self.param2conditioner_map.items()
157
+ }
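For a two-dimensional weight, `ConditionedParameter` splits the conditioner head's output into two column/row pairs plus a gate, builds two rank-1 matrices from them, and returns `max_scale * gate * (grad * a + b)` as the proposed parameter shift. The sketch below reproduces just that rank-1 construction with made-up dimensions; the real condition comes from `LSTMConditioner`, and the weight-normalized two-layer head is simplified here to a single linear layer.

```python
import torch

# Illustrative dimensions only: a (4, 3) weight conditioned on an 8-dim condition vector.
out_dim, in_dim, cond_dim = 4, 3, 8
condition = torch.randn(1, cond_dim)          # stand-in for the LSTMConditioner output (batch of 1)
head = torch.nn.Linear(cond_dim, 2 * (out_dim + in_dim) + 1)  # simplified conditioner head

col_a, row_a, col_b, row_b, gate = head(condition).split(
    [in_dim, out_dim, in_dim, out_dim, 1], dim=-1
)

a = row_a.softmax(-1).T @ col_a               # rank-1 matrix of shape (out_dim, in_dim)
b = row_b.softmax(-1).T @ col_b               # rank-1 matrix of shape (out_dim, in_dim)

grad = torch.randn(out_dim, in_dim)           # gradient of the editing loss w.r.t. this weight
max_scale = 1e-3
shift = max_scale * gate.sigmoid().mean() * (grad * a + b)
print(shift.shape)                            # torch.Size([4, 3])
```

`OneShotLearner` builds one such conditioner per parameter named in `include_set` and returns a dict of shifts keyed by parameter name; these can then be applied to a functional copy of the model (e.g. via `higher`'s `monkeypatch`, imported at the top of this file) rather than overwriting the original weights.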