antitheft159 commited on
Commit
45306c3
1 Parent(s): 20b5ed7

Upload 2 files

Browse files
1527_159_252_1434_144.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
1527_159_252_1434_144.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """1527.159.252.1434.144
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1wIIHKVp7xmSZhl44znoh45yNm8bZPAYi
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ import numpy as np
14
+
15
+ class DirectionManipulator:
16
+ def __init__(self, latent_dim, target_direction):
17
+ self.latent_dim = latent_dim
18
+ self.target_direction = target_direction / torch.norm(target_direction)
19
+
20
+ self.transform = nn.Parameter(torch.eye(latent_dim))
21
+ self.optimizer = optim.Adam([self.transform], lr=0.001)
22
+
23
+ def get_direction(self, vectors):
24
+ centered = vectors = vectors.mean(dim=0, keepdim=True)
25
+
26
+ U, S, V = torch.svd(centered)
27
+
28
+ return V[:, 0]
29
+
30
+ def transform_vectors(self, vectors):
31
+ return torch.matmul(vectors, self.transform)
32
+
33
+ def compute_loss(self, vectors):
34
+ transformed = self.transform_vectors(vectors)
35
+ current_directoin = self.get_direction(transformed)
36
+
37
+ alignment_loss = -torch.abs(torch.dot(current_direction, self.target_direction))
38
+
39
+ identity = torch.eye(self.latent_dim, device=vectors.device)
40
+ orthogonality_loss = torch.norm(torch.matmul(self.transform, self.transform.t()) - identity)
41
+ return alignment_loss + 0.1 * orthogonality_loss
42
+
43
+ def train_step(self, vectors):
44
+ self.optimizer.zero_grad()
45
+ loss = self.compute_loss(vectors)
46
+ loss.backward()
47
+ self.optimizer.step()
48
+ return loss.item()
49
+
50
+ def fit(self, vectors, n_epochs=100):
51
+ losses = []
52
+ for epoch in range(n_epochs):
53
+ loss = self.train_step(vectors)
54
+ losses.append(loss)
55
+ return losses
56
+
57
+ def main():
58
+ latent_dim = 8
59
+ n_samples = 100
60
+
61
+ vectors = torch.randn(n_samples, latent_dim)
62
+ target_direction = torch.randn(latent_dim)
63
+
64
+ manipulator = DirectionManipulator(latent_dim, target_direction)
65
+ losses = manipulator.fit(vectors)
66
+
67
+ new_vectors = torch.randn(10, latent_dim)
68
+ transformed = manipulator.transform_vectors(new_vectors)
69
+
70
+ return transformed
71
+
72
+ if __name__ == "__main__":
73
+ main()
74
+
75
+ import torch
76
+ import torch.nn as nn
77
+ import torch.optim as optim
78
+ import numpy as np
79
+ import matplotlib.pyplot as plt
80
+ from mpl_toolkits.mplot3d import Axes3D
81
+
82
+ class DirectionManipulator:
83
+ def __init__(self, latent_dim, target_direction):
84
+ self.latent_dim = latent_dim
85
+ self.target_direction = target_direction / torch.norm(target_direction)
86
+ self.transform = nn.Parameter(torch.eye(latent_dim))
87
+ self.optimizer = optim.Adam([self.transform], lr=0.001)
88
+
89
+ def get_direction(self, vectors):
90
+ centered = vectors - vectors.mean(dim=0, keepdim=True)
91
+ U, S, V = torch.svd(centered)
92
+ return V[:, 0]
93
+
94
+ def transform_vectors(self, vectors):
95
+ return torch.matmul(vectors, self.transform)
96
+
97
+ def compute_loss(self, vectors):
98
+ transformed = self.transform_vectors(vectors)
99
+ current_direction = self.get_direction(transformed)
100
+ alignment_loss = -torch.abs(torch.dot(current_direction, self.target_direction))
101
+
102
+ identity = torch.eye(self.latent_dim, device=vectors.device)
103
+ orthogonality_loss = torch.norm(torch.matmul(self.transform, self.transform.t()) - identity)
104
+ return alignment_loss + 0.1 * orthogonality_loss
105
+
106
+ def train_step(self, vectors):
107
+ self.optimizer.zero_grad()
108
+ loss =self.compute_loss(vectors)
109
+ loss.backward()
110
+ self.optimizer.step()
111
+ return loss.item()
112
+
113
+ def fit(self, vectors, n_epochs=100):
114
+ losses = []
115
+ for epoch in range(n_epochs):
116
+ loss = self.train_step(vectors)
117
+ losses.append(loss)
118
+ return losses
119
+
120
+ def visualize_transformation(original_vectors, transformed_vectors, original_direction, target_direction, title):
121
+ fig = plt.figure(figsize=(15, 5))
122
+
123
+ ax1 = fig.add_subplot(121, projection='3d')
124
+ ax1.scatter(original_vectors[:, 0],
125
+ original_vectors[:, 1],
126
+ original_vectors[:, 2],
127
+ c='blue', alpha=0.6, label='Original points')
128
+
129
+ ax1.quiver(0, 0, 0,
130
+ original_direction[0].item(),
131
+ original_direction[1].item(),
132
+ original_direction[2].item(),
133
+ color='red', linewidth=3, label='Original direction'
134
+ )
135
+
136
+ ax1.set_title('Original Data')
137
+ ax1.set_xlabel('X')
138
+ ax1.set_ylabel('Y')
139
+ ax1.set_zlabel('Z')
140
+ ax1.legend()
141
+
142
+ ax2 = fig.add_subplot(122, projection='3d')
143
+ ax2.scatter(transformed_vectors[:, 0],
144
+ transformed_vectors[:, 1],
145
+ transformed_vectors[:, 2],
146
+ c='green', alpha=0.6, label='Transformed points')
147
+
148
+ ax2.quiver(0, 0, 0,
149
+ target_direction[0].item(),
150
+ target_direction[1].item(),
151
+ target_direction[2].item(),
152
+ color='red', linewidth=3, label='Target direction')
153
+
154
+ ax2.set_title('Transformed Data')
155
+ ax2.set_xlabel('X')
156
+ ax2.set_ylabel('Y')
157
+ ax2.set_zlabel('Z')
158
+ ax2.legend()
159
+
160
+ plt.suptitle(title)
161
+ plt.tight_layout()
162
+ return fig
163
+
164
+ def main():
165
+ torch.manual_seed(42)
166
+ latent_dim = 3
167
+ n_samples =100
168
+
169
+ direction = torch.tensor([1.0, 0.2, 0.1])
170
+ noise = torch.randn(n_samples, latent_dim) * 0.3
171
+ vectors = direction.repeat(n_samples, 1) + noise
172
+
173
+ target_direction = torch.tensor([0.2, 1.0, 0.1])
174
+
175
+ manipulator = DirectionManipulator(latent_dim, target_direction)
176
+ losses = manipulator.fit(vectors, n_epochs=200)
177
+
178
+ transformed_vectors = manipulator.transform_vectors(vectors)
179
+
180
+ original_direction = manipulator.get_direction(vectors)
181
+
182
+ fig = visualize_transformation(
183
+ vectors.detach().numpy(),
184
+ transformed_vectors.detach().numpy(),
185
+ original_direction.detach(),
186
+ target_direction,
187
+ title="Direction Manipulation Visualization"
188
+ )
189
+
190
+ plt.figure(figsize=(10, 4))
191
+ plt.plot(losses)
192
+ plt.title('Training Loss')
193
+ plt.xlabel('Epoch')
194
+ plt.ylabel('Loss')
195
+ plt.grid(True)
196
+ plt.show()
197
+
198
+ return vectors, transformed_vectors, losses
199
+
200
+ if __name__ == "__main__":
201
+ main()
202
+