File size: 5,055 Bytes
c5d3e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from typing import List, Tuple


# def vec(K):
#     return K.T.flatten().reshape(-1, 1)

# def rebuild(K, r1, r2):
#     """
#     Implements the R(K) operation from the image.
#     K: input matrix (k x d)
#     r1: block height
#     r2: number of block columns
#     """
#     k, d = K.shape
#     num_block_rows = k // r1
#     num_block_cols = r2
#     bw = d // r2 # block width
    
#     blocks = []
#     # R(K) stacks vec(Ki,j) as columns. 
#     # The image shows column-major order through the blocks.
#     for j in range(num_block_cols):
#         for i in range(num_block_rows):
#             # Extract block Ki,j
#             Ki_j = K[i*r1:(i+1)*r1, j*bw:(j+1)*bw]
#             # Vectorize (column-major) and add to list
#             blocks.append(vec(Ki_j))
    
#     return torch.hstack(blocks)


def rebuild(K: torch.Tensor, r1: int, r2: int) -> torch.Tensor:
    """Rearrange a block-partitioned matrix into a matrix of vectorized blocks.

    ``K`` (shape ``(k, d)``) is split into a grid of ``k // r1`` block rows
    and ``r2`` block columns, so every block is ``r1 x (d // r2)``.  Each
    block is flattened in column-major order and becomes one column of the
    result.  Columns are ordered block-column first: all block rows of block
    column 0, then all block rows of block column 1, and so on.

    Args:
        K: 2-D tensor of shape ``(k, d)``.
        r1: Height of each block (rows per block); must divide ``k``.
        r2: Number of block columns; must divide ``d``.

    Returns:
        Tensor of shape ``(r1 * (d // r2), r2 * (k // r1))``.  Zero-element
        inputs yield a zero-element result of that shape.

    Raises:
        RuntimeError: If ``k`` is not a multiple of ``r1`` or ``d`` is not a
            multiple of ``r2``.
    """
    k, d = K.shape
    if k % r1 != 0 or d % r2 != 0:
        # Same exception type torch raised for the shape mismatch before,
        # but with a message that names the actual problem.
        raise RuntimeError(
            f"rebuild: K of shape ({k}, {d}) cannot be split into blocks "
            f"with r1={r1} rows and r2={r2} block columns"
        )

    block_rows = k // r1   # number of block rows
    block_width = d // r2  # columns per block
    vec_len = r1 * block_width

    # Splitting each dimension in two is always a valid view:
    # grid[i, a, j, b] == K[i * r1 + a, j * block_width + b]
    grid = K.view(block_rows, r1, r2, block_width)

    # Column-major flatten of an (r1, block_width) block equals a row-major
    # flatten of its transpose, so put (block_width, r1) last and merge them.
    vecs = grid.permute(0, 2, 3, 1).reshape(block_rows, r2, vec_len)

    # Order the vectors block-column first, then block-row.
    vecs = vecs.permute(1, 0, 2)

    # Use the explicit vec_len instead of -1: inferring a dimension is
    # ambiguous (and raises) when the tensor has zero elements.
    return vecs.reshape(r2 * block_rows, vec_len).t()

# def rebuild(grad, block_size: [int, int]):
#     new_matrix_rows = []
#     if grad.dim() == 2:  # only handle 2-D gradients (matrices)
#         # get the size of the gradient matrix
#         rows, cols = grad.size()
#         # iterate over the blocks
#         for j in range(0, cols, block_size[1]):
#             for i in range(0, rows, block_size[0]):
#                 # extract the current block
#                 block = grad[i:i + block_size[0], j:j + block_size[1]]
#                 # zero-pad the block if it is smaller than block_size
#                 if block.size(0) < block_size[0] or block.size(1) < block_size[1]:
#                     padding = (
#                         0, block_size[1] - block.size(1),  # column padding
#                         0, block_size[0] - block.size(0)  # row padding
#                     )
#                     block = torch.nn.functional.pad(block, padding, "constant", 0)
#                 # vectorize the block and append it as a row of the new matrix
#                 new_matrix_rows.append(block.T.flatten())

#         # stack all rows into a new matrix
#     if new_matrix_rows:  # if there is any data
#         new_gad = torch.stack(new_matrix_rows)
#     else:
#         new_gad = torch.empty(0)  # no gradient data: return an empty tensor

#     return new_gad

# if __name__ == "__main__":
#     # define a simple model
#     class SimpleModel(torch.nn.Module):
#         def __init__(self):
#             super(SimpleModel, self).__init__()
#             self.fc1 = torch.nn.Linear(10, 20)
#             self.fc2 = torch.nn.Linear(20, 10)

#         def forward(self, x):
#             x = self.fc1(x)
#             x = self.fc2(x)
#             return x

#     # initialize the model, loss function and optimizer
#     model = SimpleModel()
#     criterion = torch.nn.MSELoss()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

#     # simulated input data
#     inputs = torch.randn(5, 10)  # batch_size=5, input_size=10
#     targets = torch.randn(5, 10)  # batch_size=5, output_size=10

#     # forward pass
#     outputs = model(inputs)
#     loss = criterion(outputs, targets)

#     # backward pass to compute gradients
#     loss.backward()

#     # split each gradient into blocks and build the new matrix
#     # NOTE: this demo targets the legacy rebuild(grad, block_size) signature.
#     for param in model.parameters():
#         if param.grad is not None:  # skip parameters without a gradient
#             grad = param.grad  # fetch the gradient
#             print("旧矩阵的内容:\n", grad)
#             print("新矩阵的形状:", grad.shape)
#             block_size = (2, 2)  # block size is 2x2
#             new_matrix = rebuild(grad, block_size)
#             print("新矩阵的形状:", new_matrix.shape)
#             print("新矩阵的内容:\n", new_matrix)