tommymarto committed
Commit f411a29
1 Parent(s): 0620d56

new config and new MCQBert classes

Files changed (3):
  1. config.json +6 -1
  2. configuration_mcqbert.py +10 -0
  3. modeling_mcqbert.py +48 -0
config.json CHANGED
@@ -1,8 +1,12 @@
 {
   "_name_or_path": "epfl-ml4ed/MCQBert",
   "architectures": [
-    "BertModel"
+    "MCQBert"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_mcqbert.MCQBertConfig",
+    "AutoModelForCausalLM": "modeling_mcqbert.MCQBert"
+  },
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
   "cls_hidden_size": 256,
@@ -10,6 +14,7 @@
   "hidden_dropout_prob": 0.1,
   "hidden_size": 768,
   "initializer_range": 0.02,
+  "integration_strategy": null,
   "intermediate_size": 3072,
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 512,
configuration_mcqbert.py ADDED
@@ -0,0 +1,10 @@
+from transformers import BertConfig
+
+class MCQBertConfig(BertConfig):
+    model_type = "mcqbert"
+
+    def __init__(self, integration_strategy=None, student_embedding_size=4096, cls_hidden_size=256, **kwargs):
+        super().__init__(**kwargs)
+        self.integration_strategy = integration_strategy
+        self.student_embedding_size = student_embedding_size
+        self.cls_hidden_size = cls_hidden_size
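The three constructor arguments select among the model variants implemented in modeling_mcqbert.py below. A small usage sketch, assuming the file is importable locally (the values shown are the defaults):

    from configuration_mcqbert import MCQBertConfig

    config_plain = MCQBertConfig()                            # MCQBert: student embeddings ignored
    config_cat = MCQBertConfig(integration_strategy="cat")    # MCQStudentBertCat
    config_sum = MCQBertConfig(integration_strategy="sum",    # MCQStudentBertSum
                               student_embedding_size=4096,
                               cls_hidden_size=256)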
modeling_mcqbert.py ADDED
@@ -0,0 +1,48 @@
+import torch
+from transformers import BertModel
+
+from .configuration_mcqbert import MCQBertConfig
+
+
+class MCQBert(BertModel):
+    config_class = MCQBertConfig
+
+    def __init__(self, config: MCQBertConfig):
+        super().__init__(config)
+        self.student_embedding_layer = torch.nn.Linear(config.student_embedding_size, config.hidden_size)
+
+        # "cat" feeds the classifier the [CLS] state concatenated with the projected student embedding
+        cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
+        cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier
+
+        self.classifier = torch.nn.Sequential(
+            torch.nn.Linear(cls_input_dim, config.cls_hidden_size),
+            torch.nn.ReLU(),
+            torch.nn.Linear(config.cls_hidden_size, 1)
+        )
+
+    def forward(self, input_ids, student_embeddings=None):
+        if self.config.integration_strategy is None:
+            # don't consider student embeddings if there is no integration strategy (plain MCQBert)
+            student_embeddings = torch.zeros(input_ids.size(0), self.config.student_embedding_size, device=input_ids.device)
+
+            input_embeddings = self.embeddings(input_ids)
+            combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1).repeat(1, input_embeddings.size(1), 1)
+            output = super().forward(inputs_embeds=combined_embeddings)
+            return self.classifier(output.last_hidden_state[:, 0, :])
+
+        elif self.config.integration_strategy == "cat":
+            # MCQStudentBertCat: concatenate the [CLS] state with the projected student embedding
+            output = super().forward(input_ids)
+            output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings)), dim=1)
+            return self.classifier(output_with_student_embedding)
+
+        elif self.config.integration_strategy == "sum":
+            # MCQStudentBertSum: add the projected student embedding to every token embedding
+            input_embeddings = self.embeddings(input_ids)
+            combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1).repeat(1, input_embeddings.size(1), 1)
+            output = super().forward(inputs_embeds=combined_embeddings)
+            return self.classifier(output.last_hidden_state[:, 0, :])
+
+        else:
+            raise ValueError(f"{self.config.integration_strategy} is not a known integration_strategy")
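Finally, a hedged end-to-end sketch of a forward pass through the committed model. Batch size, sequence length, and the random inputs are illustrative, and the call assumes weights for this architecture are hosted in the repo; with the null integration_strategy from config.json, the passed student embeddings are ignored:

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("epfl-ml4ed/MCQBert", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("epfl-ml4ed/MCQBert", trust_remote_code=True)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))            # (batch, seq_len)
    student_embeddings = torch.randn(2, config.student_embedding_size)  # one 4096-d vector per student
    logits = model(input_ids, student_embeddings)                       # shape (2, 1): one score per example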