param-bharat committed on
Commit 5ba8934 · verified · 1 Parent(s): 1610d69

Upload model

Files changed (4):
  1. config.json +23 -0
  2. configuration.py +17 -0
  3. model.safetensors +3 -0
  4. modeling.py +128 -0
config.json ADDED
@@ -0,0 +1,23 @@
{
  "architectures": [
    "MultiHeadModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration.MultiHeadConfig",
    "AutoModel": "modeling.MultiHeadModel"
  },
  "classifier_dropout": 0.1,
  "encoder_name": "tasksource/deberta-base-long-nli",
  "id2label": {
    "0": "irrelevant",
    "1": "relevant"
  },
  "label2id": {
    "irrelevant": 0,
    "relevant": 1
  },
  "model_type": "multihead",
  "tokenizer_class": "DebertaV2TokenizerFast",
  "torch_dtype": "float32",
  "transformers_version": "4.47.0"
}
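
The auto_map block above registers the custom classes with the Transformers Auto loaders, so the checkpoint can be loaded with trust_remote_code=True. A minimal loading sketch; the repo id below is a placeholder, not a name taken from this commit:

from transformers import AutoConfig, AutoModel, AutoTokenizer

repo_id = "param-bharat/multihead-model"  # placeholder repo id, replace with the actual one

# trust_remote_code=True lets Transformers import configuration.py and modeling.py from the repo
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# config.json pins the tokenizer class to DebertaV2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(repo_id)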
configuration.py ADDED
@@ -0,0 +1,17 @@
from transformers import PretrainedConfig

class MultiHeadConfig(PretrainedConfig):
    model_type = "multihead"

    def __init__(
        self,
        encoder_name="microsoft/deberta-v3-small",
        **kwargs
    ):
        self.encoder_name = encoder_name
        self.classifier_dropout = kwargs.get("classifier_dropout", 0.1)
        self.num_labels = kwargs.get("num_labels", 2)
        self.id2label = kwargs.get("id2label", {0: "irrelevant", 1: "relevant"})
        self.label2id = kwargs.get("label2id", {"irrelevant": 0, "relevant": 1})
        self.tokenizer_class = kwargs.get("tokenizer_class", "DebertaV2TokenizerFast")
        super().__init__(**kwargs)
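
Because MultiHeadConfig reads each optional field out of **kwargs with a default, individual values can be adjusted after loading the config and before building the model. A small sketch, reusing the placeholder repo id from above:

from transformers import AutoConfig, AutoModel

repo_id = "param-bharat/multihead-model"  # placeholder repo id, replace with the actual one

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.classifier_dropout = 0.2  # override the 0.1 stored in config.json

model = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True)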
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7feb761acc294d07e857f52162ab3b73a5f575e9f2f6ab73b1cef39deac6415
size 735369560
modeling.py ADDED
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, AutoModel, AutoConfig
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional
from .configuration import MultiHeadConfig

@dataclass
class MultiHeadOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    doc_logits: torch.FloatTensor = None
    sent_logits: torch.FloatTensor = None
    hidden_states: Optional[torch.FloatTensor] = None
    attentions: Optional[torch.FloatTensor] = None

class MultiHeadPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """
    config_class = MultiHeadConfig
    base_model_prefix = "multihead"
    supports_gradient_checkpointing = True

class MultiHeadModel(MultiHeadPreTrainedModel):
    def __init__(self, config: MultiHeadConfig):
        super().__init__(config)

        # Shared encoder; both classification heads sit on top of its hidden states.
        self.encoder = AutoModel.from_pretrained(config.encoder_name)

        self.classifier_dropout = nn.Dropout(config.classifier_dropout)
        self.doc_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)
        self.sent_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)

        # Scoring layers for attentive pooling at document and sentence level.
        self.doc_attention = nn.Linear(self.encoder.config.hidden_size, 1)
        self.sent_attention = nn.Linear(self.encoder.config.hidden_size, 1)

        self.post_init()

    def attentive_pooling(self, hidden_states, mask, attention_layer, sentence_mode=False):
        if not sentence_mode:
            # Document mode: mask is (batch, seq_len); pool all unmasked tokens
            # into one (batch, hidden) vector.
            attention_scores = attention_layer(hidden_states).squeeze(-1)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=1)
            pooled_output = torch.bmm(attention_weights.unsqueeze(1), hidden_states)
            return pooled_output.squeeze(1)
        else:
            # Sentence mode: mask is (batch, num_sentences, seq_len); pool each
            # sentence's tokens into a (batch, num_sentences, hidden) tensor.
            batch_size, num_sentences, seq_len = mask.size()
            attention_scores = attention_layer(hidden_states).squeeze(-1).unsqueeze(1)
            attention_scores = attention_scores.expand(batch_size, num_sentences, seq_len)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=2)

            pooled_output = torch.bmm(attention_weights, hidden_states)
            return pooled_output

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        document_labels=None,
        sentence_positions=None,
        sentence_labels=None,
        return_dict=True,
        **kwargs
    ):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state

        # Document-level representation and logits.
        doc_repr = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=attention_mask.bool(),
            attention_layer=self.doc_attention,
            sentence_mode=False
        )
        doc_repr = self.classifier_dropout(doc_repr)
        doc_logits = self.doc_classifier(doc_repr)

        # Build a (batch, max_sents, seq_len) boolean mask that marks, for each
        # sentence slot, the token position it refers to; -1 marks padded slots.
        batch_size, max_sents = sentence_positions.size()
        seq_len = attention_mask.size(1)

        valid_mask = (sentence_positions != -1)
        safe_positions = sentence_positions.masked_fill(~valid_mask, 0)

        sentence_tokens_mask = torch.zeros(batch_size, max_sents, seq_len, dtype=torch.bool, device=attention_mask.device)
        batch_idx = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)
        sent_idx = torch.arange(max_sents, device=input_ids.device).unsqueeze(0)
        sentence_tokens_mask[batch_idx, sent_idx, safe_positions] = valid_mask

        # Sentence-level representations and logits.
        sent_reprs = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=sentence_tokens_mask,
            attention_layer=self.sent_attention,
            sentence_mode=True
        )
        sent_reprs = self.classifier_dropout(sent_reprs)
        sent_logits = self.sent_classifier(sent_reprs)

        loss = None
        if document_labels is not None:
            doc_loss_fct = CrossEntropyLoss()
            doc_loss = doc_loss_fct(doc_logits, document_labels)

            if sentence_labels is not None:
                # Padded sentence slots carry label -100 and are ignored; the
                # sentence loss is weighted twice as heavily as the document loss.
                sent_loss_fct = CrossEntropyLoss(ignore_index=-100)
                sent_logits_flat = sent_logits.view(-1, sent_logits.size(-1))
                sentence_labels_flat = sentence_labels.view(-1)
                sent_loss = sent_loss_fct(sent_logits_flat, sentence_labels_flat)
                loss = doc_loss + (2 * sent_loss)
            else:
                loss = doc_loss

        if not return_dict:
            return (loss, doc_logits, sent_logits)

        return MultiHeadOutput(
            loss=loss,
            doc_logits=doc_logits,
            sent_logits=sent_logits,
            hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            attentions=outputs.attentions if hasattr(outputs, "attentions") else None,
        )
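
A forward-pass sketch under the assumptions the code implies: sentence_positions holds one token index per sentence (padded with -1 for unused slots), sentence_labels is padded with -100 so those slots are ignored by the loss, and document_labels holds one class id per example. The text, token indices, and repo id below are illustrative placeholders:

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "param-bharat/multihead-model"  # placeholder repo id, replace with the actual one
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

enc = tokenizer("The report covers Q3 revenue. The weather was mild.", return_tensors="pt")

sentence_positions = torch.tensor([[1, 8, -1]])  # one representative token index per sentence; -1 pads
document_labels = torch.tensor([1])              # 1 = "relevant"
sentence_labels = torch.tensor([[1, 0, -100]])   # -100 = padded sentence slot, ignored by the loss

out = model(
    **enc,
    document_labels=document_labels,
    sentence_positions=sentence_positions,
    sentence_labels=sentence_labels,
)
print(out.loss, out.doc_logits.shape, out.sent_logits.shape)

The combined loss is doc_loss + 2 * sent_loss, as defined in forward above; omitting the label arguments returns the logits with loss set to None.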