cognitivess committed on
Commit 5e68fa9 · verified · 1 Parent(s): 0a5a5fb

Update cognitivess_model/modeling_cognitivess.py

cognitivess_model/modeling_cognitivess.py CHANGED
@@ -1,14 +1,116 @@
+# cognitivess_model/modeling_cognitivess.py
+
 import torch
-from torch import nn
+import torch.nn as nn
 from transformers import PreTrainedModel
-
-class CognitivessForCausalLM(PreTrainedModel):
-    def __init__(self, config):
+from .configuration_cognitivess import CognitivessConfig
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, dropout_prob=0.0):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = hidden_size // num_attention_heads
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(hidden_size, self.all_head_size)
+        self.key = nn.Linear(hidden_size, self.all_head_size)
+        self.value = nn.Linear(hidden_size, self.all_head_size)
+        self.dense = nn.Linear(hidden_size, hidden_size)
+
+        self.dropout = nn.Dropout(dropout_prob)
+
+    def forward(self, hidden_states, attention_mask=None):
+        batch_size, seq_length, hidden_size = hidden_states.size()
+
+        query_layer = self.query(hidden_states).view(batch_size, seq_length, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        key_layer = self.key(hidden_states).view(batch_size, seq_length, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(batch_size, seq_length, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = attention_scores / torch.sqrt(torch.tensor(self.attention_head_size, dtype=torch.float32))
+
+        if attention_mask is not None:
+            attention_scores = attention_scores + attention_mask
+
+        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
+        attention_probs = self.dropout(attention_probs)
+
+        context_layer = torch.matmul(attention_probs, value_layer).transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)
+        output_layer = self.dense(context_layer)
+
+        return output_layer
+
+class FeedForward(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, hidden_act, mlp_bias):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, intermediate_size, bias=mlp_bias)
+        self.activation = nn.SiLU() if hidden_act == "silu" else nn.ReLU()
+        self.output = nn.Linear(intermediate_size, hidden_size, bias=mlp_bias)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        hidden_states = self.output(hidden_states)
+        return hidden_states
+
+class TransformerBlock(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, intermediate_size, hidden_act, layer_norm_eps, mlp_bias, attention_dropout):
+        super().__init__()
+        self.attention = MultiHeadAttention(hidden_size, num_attention_heads, dropout_prob=attention_dropout)
+        self.feed_forward = FeedForward(hidden_size, intermediate_size, hidden_act, mlp_bias)
+        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, hidden_states, attention_mask=None):
+        # Attention
+        attention_output = self.attention(hidden_states, attention_mask)
+        hidden_states = self.layer_norm1(hidden_states + attention_output)
+
+        # Feed Forward
+        feed_forward_output = self.feed_forward(hidden_states)
+        hidden_states = self.layer_norm2(hidden_states + feed_forward_output)
+
+        return hidden_states
+
+class CognitivessModel(PreTrainedModel):
+    def __init__(self, config: CognitivessConfig):
         super().__init__(config)
-        self.transformer = nn.Transformer(d_model=config.hidden_size, num_encoder_layers=config.num_hidden_layers)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-    def forward(self, input_ids, attention_mask=None, labels=None):
-        outputs = self.transformer(input_ids)
-        logits = self.lm_head(outputs)
-        return logits
+        self.config = config
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.layers = nn.ModuleList([
+            TransformerBlock(
+                config.hidden_size,
+                config.num_attention_heads,
+                config.intermediate_size,
+                config.hidden_act,
+                config.layer_norm_eps,
+                config.mlp_bias,
+                config.attention_dropout
+            )
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.activation = nn.SiLU() if config.hidden_act == "silu" else nn.ReLU()
+
+    def forward(self, input_ids, attention_mask=None):
+        # Embeddings
+        embeddings = self.embeddings(input_ids)
+        position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
+        position_embeddings = self.position_embeddings(position_ids)
+        embeddings = embeddings + position_embeddings
+
+        # Transformer Layers
+        hidden_states = embeddings
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask)
+
+        # Pooler
+        pooled_output = self.pooler(hidden_states[:, 0])
+        return pooled_output
+
+# Define CognitivessForCausalLM as an alias to CognitivessModel if needed
+CognitivessForCausalLM = CognitivessModel
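
For context, a minimal smoke-test sketch of the committed module follows. It is not part of the commit: it assumes the package layout implied by the file path (cognitivess_model with configuration_cognitivess.py next to this file) and that CognitivessConfig accepts the fields referenced above as constructor keyword arguments, as is typical for PretrainedConfig subclasses; all concrete values are illustrative. Two behaviours visible in the code itself are worth noting: attention_mask is added directly to the raw attention scores, so it must be an additive float mask broadcastable to (batch, num_heads, seq_len, seq_len), and CognitivessModel.forward returns a pooled (batch, hidden_size) tensor rather than per-token logits, which the CognitivessForCausalLM alias inherits.

# smoke_test_sketch.py -- illustrative only; import paths and config values are assumptions
import torch

from cognitivess_model.configuration_cognitivess import CognitivessConfig
from cognitivess_model.modeling_cognitivess import CognitivessModel

# Assumption: CognitivessConfig exposes these fields as constructor kwargs
# (configuration_cognitivess.py is not shown in this diff).
config = CognitivessConfig(
    vocab_size=32000,
    hidden_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=512,
    hidden_act="silu",
    layer_norm_eps=1e-5,
    mlp_bias=False,
    attention_dropout=0.0,
    max_position_embeddings=128,
    pad_token_id=0,
)

model = CognitivessModel(config)
model.eval()

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

# MultiHeadAttention adds the mask to the raw scores, so pass an additive
# float mask: 0.0 where attention is allowed, a large negative value where
# it is not, broadcastable to (batch, num_heads, seq_len, seq_len).
key_padding = torch.zeros(batch_size, seq_len)         # 0.0 = keep every token
additive_mask = key_padding[:, None, None, :] * -1e9   # shape (batch, 1, 1, seq_len)

with torch.no_grad():
    pooled = model(input_ids, attention_mask=additive_mask)

print(pooled.shape)  # torch.Size([2, 256]): pooled output, not vocabulary logits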