Prositron committed
Commit f6f9962 · verified · 1 Parent(s): 31275b5

Update tensor_network.py

Files changed (1)
  1. tensor_network.py +67 -81
tensor_network.py CHANGED
@@ -1,95 +1,81 @@
 import torch
 import torch.nn as nn

-# Define an enhanced neural network model with more layers and self-attention
-class ComplexModel(nn.Module):
-    def __init__(self):
-        super(ComplexModel, self).__init__()
-        # First convolutional layer: input channels=3, output channels=16
-        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
-        self.bn1 = nn.BatchNorm2d(16)
-
-        # Second convolutional layer: input channels=16, output channels=32
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
-        self.bn2 = nn.BatchNorm2d(32)
-
-        # Max pooling to reduce spatial dimensions by a factor of 2
-        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
-
-        # Third convolutional layer: input channels=32, output channels=64
-        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
-        self.bn3 = nn.BatchNorm2d(64)
-
-        # Self-attention layer:
-        # After conv3, the feature map is expected to be of shape [batch, 64, 2, 2].
-        # We treat the spatial dimensions (2x2=4 tokens) as the sequence length.
-        # For nn.MultiheadAttention, the embed dimension is 64.
-        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=4)
+class BrainInspiredTransformer(nn.Module):
+    def __init__(self, num_layers=16, embed_dim=7, num_heads=1, num_extra_tokens=16, num_classes=10):
+        super(BrainInspiredTransformer, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_extra_tokens = num_extra_tokens
+
+        # Project the 3-channel input into a 7-dimensional embedding space.
+        self.embedding = nn.Conv2d(3, embed_dim, kernel_size=1)
+
+        # Learnable extra tokens (to augment the 4x4 grid tokens).
+        self.extra_tokens = nn.Parameter(torch.randn(num_extra_tokens, embed_dim))
+
+        # Build a stack of self-attention layers with layer normalization.
+        self.attention_layers = nn.ModuleList([
+            nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
+            for _ in range(num_layers)
+        ])
+        self.layer_norms = nn.ModuleList([
+            nn.LayerNorm(embed_dim)
+            for _ in range(num_layers)
+        ])

-        # Fully connected layers:
-        # After conv3 and attention, the tensor shape remains [batch, 64, 2, 2],
-        # so the flattened feature size is 64 * 2 * 2 = 256.
-        self.fc1 = nn.Linear(64 * 2 * 2, 128)
-        self.fc2 = nn.Linear(128, 10)  # For example, output layer with 10 classes
+        # GRU cell for recurrent updating, mimicking working memory or recurrent feedback.
+        # It processes each token (with dimension=embed_dim) in a brain-inspired manner.
+        self.gru = nn.GRUCell(embed_dim, embed_dim)
+
+        # Final classification head.
+        # We have 16 tokens from the 4x4 grid and num_extra_tokens extra tokens.
+        # Flattened feature dimension is (16 + num_extra_tokens) * embed_dim.
+        self.fc = nn.Linear((16 + num_extra_tokens) * embed_dim, num_classes)

     def forward(self, x):
-        # First conv layer with batch normalization and ReLU activation
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = torch.relu(x)
+        # x: [batch, 3, 4, 4]
+        batch_size = x.size(0)

-        # Second conv layer with batch normalization and ReLU activation
-        x = self.conv2(x)
-        x = self.bn2(x)
-        x = torch.relu(x)
+        # Embed the input: [batch, 3, 4, 4] -> [batch, embed_dim, 4, 4]
+        x = self.embedding(x)

-        # Pooling to reduce spatial dimensions
-        x = self.pool(x)
-
-        # Third conv layer with batch normalization and ReLU activation
-        x = self.conv3(x)
-        x = self.bn3(x)
-        x = torch.relu(x)
-
-        # --------- Self-Attention Block ---------
-        # x shape: [batch_size, channels=64, height=2, width=2]
-        batch, channels, height, width = x.shape
-        # Flatten spatial dimensions: create a sequence of tokens.
-        # New shape: [batch_size, channels, sequence_length] where sequence_length = height * width (4 tokens)
-        x_flat = x.view(batch, channels, height * width)  # Shape: [B, 64, 4]
-        # Permute to match nn.MultiheadAttention input: [sequence_length, batch_size, embed_dim]
-        x_flat = x_flat.permute(2, 0, 1)  # Shape: [4, B, 64]
-
-        # Apply self-attention (keys, queries, and values are all x_flat)
-        attn_output, _ = self.attention(x_flat, x_flat, x_flat)
-        # attn_output shape remains: [4, B, 64]
-
-        # Permute back to [batch_size, channels, sequence_length]
-        x_flat = attn_output.permute(1, 2, 0)  # Shape: [B, 64, 4]
-        # Reshape back to spatial dimensions: [B, 64, 2, 2]
-        x = x_flat.view(batch, channels, height, width)
-        # --------- End Self-Attention Block ---------
-
-        # Flatten the tensor for the fully connected layers
-        x = x.view(x.size(0), -1)  # Flatten to [batch, 256]
-        x = self.fc1(x)
-        x = torch.relu(x)
-        x = self.fc2(x)
-        return x
-
-# Example of creating input tensors (each with shape: batch_size=2, channels=3, height=4, width=4)
-tensor1 = torch.rand(2, 3, 4, 4)
-tensor2 = torch.rand(2, 3, 4, 4)
-tensor3 = torch.rand(2, 3, 4, 4)
+        # Flatten spatial dimensions: [batch, embed_dim, 4, 4] -> [batch, embed_dim, 16]
+        # Then permute to [sequence_length, batch, embed_dim] for attention.
+        x = x.view(batch_size, self.embed_dim, -1).permute(2, 0, 1)  # [16, batch, 7]

-# Adding the tensors element-wise to form the input tensor
-input_tensor = tensor1 + tensor2 + tensor3
+        # Expand and concatenate extra tokens: extra_tokens [num_extra_tokens, embed_dim]
+        # becomes [num_extra_tokens, batch, embed_dim] and concatenated along sequence dim.
+        extra_tokens = self.extra_tokens.unsqueeze(1).expand(-1, batch_size, -1)
+        x = torch.cat([x, extra_tokens], dim=0)  # [16 + num_extra_tokens, batch, 7]

-# Initialize the enhanced model
-model = ComplexModel()
+        # Process through the transformer layers with recurrent GRU updates.
+        for attn, norm in zip(self.attention_layers, self.layer_norms):
+            residual = x
+            attn_out, _ = attn(x, x, x)
+            # Residual connection and layer normalization.
+            x = norm(residual + attn_out)
+
+            # --- Brain-inspired recurrent update ---
+            # Reshape tokens to apply GRUCell in parallel.
+            seq_len, batch, embed_dim = x.shape
+            x_flat = x.view(seq_len * batch, embed_dim)
+            # Use the same x_flat as both input and hidden state.
+            x_updated_flat = self.gru(x_flat, x_flat)
+            x = x_updated_flat.view(seq_len, batch, embed_dim)
+            # --- End recurrent update ---
+
+        # Rearrange back to [batch, sequence_length, embed_dim] and flatten.
+        x = x.permute(1, 0, 2).contiguous()
+        x = x.view(batch_size, -1)
+
+        # Classification head.
+        out = self.fc(x)
+        return out

-# Forward pass through the model
+# Example usage:
+input_tensor = torch.rand(2, 3, 4, 4)  # [batch=2, channels=3, height=4, width=4]
+model = BrainInspiredTransformer(num_layers=16, embed_dim=7, num_heads=1, num_extra_tokens=16, num_classes=10)
 output = model(input_tensor)

 print("Output shape:", output.shape)
-print("Output:", output)
+print("Output:", output)