| |
|
|
| import pytorch_lightning as pl |
| import torch |
| from sklearn.cluster import KMeans |
| import numpy as np |
|
|
class RBFNetwork(pl.LightningModule):
    """Radial-basis-function network used to learn a data-dependent metric.

    The network computes
    ``h(x) = sum_k W_k * exp(-0.5 * lamda_k * ||x - C_k||^2)`` and is trained
    so that ``h`` is close to 1 on the training samples. The centers ``C``
    and bandwidths ``lamda`` are fitted once, in :meth:`on_train_start`, from
    a K-means clustering of the first training batch. Only ``W`` is learned
    by the optimizer.

    NOTE: ``C`` and ``lamda`` are plain tensor attributes, not registered
    buffers. They are therefore not part of ``state_dict`` and do not follow
    ``.to(device)``; :meth:`forward` moves its input to ``C``'s device
    instead.

    Args:
        current_timestep: Stored on the instance; not used inside this class.
        next_timestep: Stored on the instance; not used inside this class.
        n_centers: Number of RBF centers / K-means clusters (``K``).
        kappa: Bandwidth scale; ``lamda = 0.5 / (kappa * sigma) ** 2``.
        lr: Adam learning rate.
        datamodule: Stored on the instance. Fitting reads
            ``self.trainer.datamodule`` instead.
        image_data: If True, centers are recomputed as per-label means, and
            each cluster's per-feature variance is summed rather than
            averaged.
        args: Config object. ``args.data_type`` and ``args.branches`` are
            read when unpacking batches.
    """

    def __init__(
        self,
        current_timestep,
        next_timestep,
        n_centers: int = 100,
        kappa: float = 1.0,
        lr=1e-2,
        datamodule=None,
        image_data=False,
        args=None,
    ):
        super().__init__()
        self.K = n_centers
        self.current_timestep = current_timestep
        self.next_timestep = next_timestep
        self.clustering_model = KMeans(n_clusters=self.K)
        self.kappa = kappa
        # Read by compute_metric when epsilon < 0; refreshed on every validation step.
        self.last_val_loss = 1
        self.lr = lr
        # One non-negative mixture weight per center (floored in on_before_zero_grad).
        self.W = torch.nn.Parameter(torch.rand(self.K, 1))
        self.datamodule = datamodule
        self.image_data = image_data
        self.args = args

    def on_before_zero_grad(self, *args, **kwargs):
        """Clamp the weights ``W`` to be at least 1e-4 whenever Lightning zeroes grads."""
        # Assign through ``.data`` rather than clamping in place. This leaves
        # the parameter's autograd version counter untouched, so a forward
        # graph that still references W stays valid.
        self.W.data = torch.clamp(self.W.data, min=0.0001)

    def on_train_start(self):
        """Fit the RBF centers ``C`` and bandwidths ``lamda`` before training.

        Clusters the ``metric_samples`` of the first training batch with
        K-means. For each cluster ``k`` it sets
        ``sigma_k = sqrt(agg(per-feature variance))``, where ``agg`` is a sum
        for image data and a mean otherwise. It then sets
        ``lamda_k = 0.5 / (kappa * max(sigma_k, 1e-6)) ** 2``.
        """
        with torch.no_grad():
            batch = next(iter(self.trainer.datamodule.train_dataloader()))
            metric_samples = batch[0]["metric_samples"][0]
            all_data = torch.cat(metric_samples)
            # Flatten exactly like forward() does and move to host NumPy once:
            # sklearn and the variance computation below need 2-D CPU arrays,
            # and mixing torch tensors with ndarrays is fragile.
            all_data = (
                all_data.reshape(all_data.shape[0], -1).detach().cpu().numpy()
            )

            print("Fitting Clustering model...")
            self.clustering_model.fit(all_data)
            labels = self.clustering_model.labels_

            clusters = (
                self.calculate_centroids(all_data, labels)
                if self.image_data
                else self.clustering_model.cluster_centers_
            )
            self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device)

            sigmas = np.zeros((self.K, 1))
            for k in range(self.K):
                points = all_data[labels == k]
                if len(points) == 0:
                    # An empty cluster would give a NaN variance, and
                    # np.maximum below does not remove NaN. Leave sigma at 0
                    # so the floor applies instead.
                    continue
                variance = ((points - clusters[k]) ** 2).mean(axis=0)
                sigmas[k, :] = np.sqrt(
                    variance.sum() if self.image_data else variance.mean()
                )

            # Floor sigma so zero-spread (e.g. single-point) clusters do not
            # produce an infinite lamda.
            sigmas = np.maximum(sigmas, 1e-6)
            self.lamda = torch.tensor(
                0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32
            ).to(self.device)

    def forward(self, x):
        """Evaluate ``h(x) = sum_k W_k * exp(-0.5 * lamda_k * ||x - C_k||^2)``.

        Since ``lamda`` already contains a factor 0.5, the effective kernel is
        ``exp(-||x - C_k||^2 / (4 * kappa^2 * sigma_k^2))``. Inputs with more
        than two dimensions are flattened to ``(batch, -1)``. The per-center
        activations are also stored in ``self.phi_x``. Requires ``C`` and
        ``lamda`` to have been set by :meth:`on_train_start`.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        if x.dim() > 2:
            x = x.reshape(x.shape[0], -1)
        x = x.to(self.C.device)

        dist2 = torch.cdist(x, self.C) ** 2  # (batch, K)
        # lamda is (K, 1) -> (1, K, 1); dist2 -> (batch, K, 1).
        self.phi_x = torch.exp(-0.5 * self.lamda[None, :, :] * dist2[:, :, None])
        return (self.W.to(x.device) * self.phi_x).sum(dim=1)

    def _collect_inputs(self, batch, split_key):
        """Concatenate the samples stored under ``split_key`` into one tensor.

        For the ``"scrna"`` and ``"tahoe"`` data types the sample dict is read
        from ``batch[0]``; otherwise it is read from ``batch`` directly. With
        ``args.branches == 1`` the inputs are ``x0`` and ``x1``. Otherwise
        they are ``x0``, ``x1_1`` and ``x1_2``. They are concatenated along
        dim 0 and moved to this module's device.
        """
        if self.args.data_type in ("scrna", "tahoe"):
            batch = batch[0]
        main_batch = batch[split_key][0]

        if self.args.branches == 1:
            keys = ("x0", "x1")
        else:
            keys = ("x0", "x1_1", "x1_2")
        return torch.cat([main_batch[key][0] for key in keys], dim=0).to(self.device)

    def training_step(self, batch, batch_idx):
        """Regress ``h`` towards 1 on the training samples (mean squared error)."""
        inputs = self._collect_inputs(batch, "train_samples")
        loss = ((1 - self.forward(inputs)) ** 2).mean()
        self.log(
            "MetricModel/train_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation MSE of ``h`` against 1 and remember it."""
        inputs = self._collect_inputs(batch, "val_samples")
        h = self.forward(inputs)
        loss = ((1 - h) ** 2).mean()
        self.log(
            "MetricModel/val_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        # Consumed by compute_metric when it is called with a negative epsilon.
        self.last_val_loss = loss.detach()
        return loss

    def calculate_centroids(self, all_data, labels):
        """Return the mean row of ``all_data`` for each distinct label.

        Row ``i`` of the result is the centroid of the ``i``-th label in
        ``np.unique(labels)`` order. The result has shape
        ``(n_unique_labels, all_data.shape[1])``.
        """
        unique_labels = np.unique(labels)
        centroids = np.zeros((len(unique_labels), all_data.shape[1]))
        for i, label in enumerate(unique_labels):
            centroids[i] = all_data[labels == label].mean(axis=0)
        return centroids

    def configure_optimizers(self):
        """Use Adam over all parameters (only ``W``) at learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False):
        """Return the metric value ``M(x)`` derived from the network output.

        By default ``M(x) = 1 / (h(x) + epsilon) ** alpha``. With
        ``image_hx`` the output is first folded to ``1 - |1 - h(x)|`` and
        ``M(x) = 1 / (h(x) ** alpha + epsilon)``. A negative ``epsilon`` is
        replaced by ``(1 - last_val_loss) / |epsilon|``.
        """
        if epsilon < 0:
            epsilon = (1 - float(self.last_val_loss)) / abs(epsilon)
        h_x = self.forward(x)
        if image_hx:
            h_x = 1 - torch.abs(1 - h_x)
            return 1 / (h_x**alpha + epsilon)
        return 1 / (h_x + epsilon) ** alpha