# 1527.159.252.1434.144 / 1527_159_252_1434_144.py
# antitheft159's picture
# Upload 2 files
# 45306c3 verified
# -*- coding: utf-8 -*-
"""1527.159.252.1434.144
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1wIIHKVp7xmSZhl44znoh45yNm8bZPAYi
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class DirectionManipulator:
    """Learn a linear map that steers the dominant direction of a point set.

    A square ``latent_dim x latent_dim`` matrix, initialised to the identity,
    is optimised with Adam so that the leading right-singular vector of the
    centred, transformed data aligns (up to sign) with ``target_direction``,
    while a penalty term keeps the matrix close to orthogonal.
    """

    def __init__(self, latent_dim, target_direction):
        """Set up the learnable transform and its optimizer.

        Args:
            latent_dim: Size of each input vector; the transform is a
                ``latent_dim x latent_dim`` matrix.
            target_direction: 1-D tensor giving the desired direction. It is
                normalised to unit length before being stored.
        """
        self.latent_dim = latent_dim
        # Keep a unit-length copy so the dot product in the loss is a cosine.
        self.target_direction = target_direction / torch.norm(target_direction)
        self.transform = nn.Parameter(torch.eye(latent_dim))
        self.optimizer = optim.Adam([self.transform], lr=0.001)

    def get_direction(self, vectors):
        """Return the leading right-singular vector of the centred data.

        Args:
            vectors: 2-D tensor with one sample per row.

        Returns:
            1-D tensor: the first column of ``V`` from the SVD of the
            mean-centred input.
        """
        # Subtract the per-column mean; the previous chained assignment
        # (`centered = vectors = mean`) decomposed the mean row instead.
        centered = vectors - vectors.mean(dim=0, keepdim=True)
        _, _, V = torch.svd(centered)
        return V[:, 0]

    def transform_vectors(self, vectors):
        """Apply the learned matrix to ``vectors`` by right-multiplication."""
        return torch.matmul(vectors, self.transform)

    def compute_loss(self, vectors):
        """Compute the alignment loss plus the weighted orthogonality penalty.

        Args:
            vectors: 2-D tensor with one sample per row.

        Returns:
            Scalar tensor ``-|<direction, target>| + 0.1 * ||T T^t - I||``.
        """
        transformed = self.transform_vectors(vectors)
        current_direction = self.get_direction(transformed)
        # abs() makes the loss sign-invariant: singular vectors are only
        # defined up to sign.
        alignment_loss = -torch.abs(torch.dot(current_direction, self.target_direction))
        identity = torch.eye(self.latent_dim, device=vectors.device)
        orthogonality_loss = torch.norm(torch.matmul(self.transform, self.transform.t()) - identity)
        return alignment_loss + 0.1 * orthogonality_loss

    def train_step(self, vectors):
        """Run one optimizer update on ``vectors`` and return the loss as a float."""
        self.optimizer.zero_grad()
        loss = self.compute_loss(vectors)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def fit(self, vectors, n_epochs=100):
        """Run ``n_epochs`` training steps on the same data.

        Args:
            vectors: 2-D tensor with one sample per row.
            n_epochs: Number of optimizer updates to perform.

        Returns:
            List of per-step loss values, in order.
        """
        losses = []
        for _ in range(n_epochs):
            losses.append(self.train_step(vectors))
        return losses
def main():
    """Fit a manipulator on random 8-D data and transform a fresh batch.

    Returns:
        The result of applying the fitted transform to 10 newly sampled
        random vectors.
    """
    dim, sample_count = 8, 100

    # Training data and target are both drawn from a standard normal.
    training_points = torch.randn(sample_count, dim)
    goal = torch.randn(dim)

    model = DirectionManipulator(dim, goal)
    model.fit(training_points)

    # Apply the learned matrix to points the model has not seen.
    return model.transform_vectors(torch.randn(10, dim))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class DirectionManipulator:
    """Trainable square matrix that re-orients a data set's principal axis.

    The matrix starts as the identity and is updated by Adam. The objective
    rewards alignment (up to sign) between the leading right-singular vector
    of the centred, transformed data and a unit-length target, and penalises
    departure of the matrix from orthogonality.
    """

    def __init__(self, latent_dim, target_direction):
        """Store the normalised target and create the parameter and optimizer.

        Args:
            latent_dim: Dimensionality of the vectors; the matrix is
                ``latent_dim x latent_dim``.
            target_direction: 1-D tensor; divided by its norm before storage.
        """
        self.latent_dim = latent_dim
        self.target_direction = target_direction / target_direction.norm()
        self.transform = nn.Parameter(torch.eye(latent_dim))
        self.optimizer = optim.Adam([self.transform], lr=0.001)

    def get_direction(self, vectors):
        """Return the first right-singular vector of the mean-centred input.

        Args:
            vectors: 2-D tensor, one sample per row.
        """
        column_means = vectors.mean(dim=0, keepdim=True)
        _, _, right_vectors = torch.svd(vectors - column_means)
        return right_vectors[:, 0]

    def transform_vectors(self, vectors):
        """Right-multiply ``vectors`` by the current transform matrix."""
        return vectors @ self.transform

    def compute_loss(self, vectors):
        """Return negative absolute alignment plus 0.1 times the orthogonality penalty.

        Args:
            vectors: 2-D tensor, one sample per row.
        """
        principal = self.get_direction(self.transform_vectors(vectors))
        # Absolute value: a singular vector and its negation are equivalent.
        alignment = torch.dot(principal, self.target_direction).abs().neg()

        eye = torch.eye(self.latent_dim, device=vectors.device)
        gram = self.transform @ self.transform.t()
        penalty = torch.norm(gram - eye)

        return alignment + 0.1 * penalty

    def train_step(self, vectors):
        """Perform a single gradient update and return the loss as a float."""
        self.optimizer.zero_grad()
        step_loss = self.compute_loss(vectors)
        step_loss.backward()
        self.optimizer.step()
        return step_loss.item()

    def fit(self, vectors, n_epochs=100):
        """Call ``train_step`` ``n_epochs`` times and return the loss history.

        Args:
            vectors: 2-D tensor, one sample per row; reused at every step.
            n_epochs: Number of updates to run.

        Returns:
            List of float losses, one per step.
        """
        return [self.train_step(vectors) for _ in range(n_epochs)]
def visualize_transformation(original_vectors, transformed_vectors, original_direction, target_direction, title):
    """Draw side-by-side 3-D scatter plots of the data before and after.

    Each panel shows the points' first three columns plus a red arrow from
    the origin along the first three components of a direction vector.

    Args:
        original_vectors: 2-D array-like indexed as ``[:, 0..2]``; plotted in
            blue on the left panel.
        transformed_vectors: Same layout; plotted in green on the right panel.
        original_direction: Sequence whose elements support ``.item()``;
            drawn as the arrow on the left panel.
        target_direction: Same kind of object; drawn on the right panel.
        title: Text for the figure-level title.

    Returns:
        The matplotlib figure containing both panels.
    """
    fig = plt.figure(figsize=(15, 5))

    # (subplot position, points, arrow, point colour, point label,
    #  arrow label, panel title) for the left and right panels.
    panels = (
        (121, original_vectors, original_direction,
         'blue', 'Original points', 'Original direction', 'Original Data'),
        (122, transformed_vectors, target_direction,
         'green', 'Transformed points', 'Target direction', 'Transformed Data'),
    )

    for position, points, arrow, point_color, point_label, arrow_label, panel_title in panels:
        axes = fig.add_subplot(position, projection='3d')
        axes.scatter(
            points[:, 0], points[:, 1], points[:, 2],
            c=point_color, alpha=0.6, label=point_label,
        )
        # Arrow anchored at the origin.
        axes.quiver(
            0, 0, 0,
            arrow[0].item(), arrow[1].item(), arrow[2].item(),
            color='red', linewidth=3, label=arrow_label,
        )
        axes.set_title(panel_title)
        axes.set_xlabel('X')
        axes.set_ylabel('Y')
        axes.set_zlabel('Z')
        axes.legend()

    plt.suptitle(title)
    plt.tight_layout()
    return fig
def main():
    """Run the 3-D demo: build data, fit the manipulator, and plot results.

    Points are sampled around ``[1.0, 0.2, 0.1]`` with standard-normal noise
    scaled by 0.3 (seed 42), the manipulator is trained for 200 steps towards
    ``[0.2, 1.0, 0.1]``, and two figures are shown: the before/after scatter
    plots and the loss curve.

    Returns:
        Tuple ``(vectors, transformed_vectors, losses)``.
    """
    torch.manual_seed(42)

    dim, sample_count = 3, 100
    base_direction = torch.tensor([1.0, 0.2, 0.1])
    jitter = torch.randn(sample_count, dim) * 0.3
    # Broadcasting adds the same base direction to every noise row.
    vectors = base_direction + jitter

    target_direction = torch.tensor([0.2, 1.0, 0.1])
    model = DirectionManipulator(dim, target_direction)
    losses = model.fit(vectors, n_epochs=200)

    transformed_vectors = model.transform_vectors(vectors)
    start_direction = model.get_direction(vectors)

    # Points are converted to NumPy for plotting; directions stay tensors
    # because the plotting helper reads them with .item().
    visualize_transformation(
        vectors.detach().numpy(),
        transformed_vectors.detach().numpy(),
        start_direction.detach(),
        target_direction,
        title="Direction Manipulation Visualization",
    )

    plt.figure(figsize=(10, 4))
    plt.plot(losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()

    return vectors, transformed_vectors, losses
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()