| | import yaml, os, time, wandb, random |
| | import numpy as np |
| | import torch |
| | import torch.optim as optim |
| | import torch.nn.functional as F |
| |
|
| | from dataclasses import dataclass |
| | from collections import defaultdict |
| | from typing import Optional |
| |
|
| | try: |
| | |
| | from mechanism_base import * |
| | from model_base import EmbedMLP |
| | except ImportError: |
| | |
| | from src.mechanism_base import * |
| | from src.model_base import EmbedMLP |
| |
|
| |
|
| | |
def set_all_seeds(seed):
    """Make runs reproducible by seeding every source of randomness.

    Seeds Python's `random`, NumPy, and torch (CPU and all CUDA devices),
    then pins cuDNN to deterministic kernels.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade cuDNN autotuning speed for bitwise-reproducible kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
| | |
def read_config():
    """Load `configs.yaml` from the directory containing this module.

    Returns:
        The parsed configuration dictionary.

    Raises:
        yaml.YAMLError: if the file is not valid YAML. The original code
        printed the error and implicitly returned None, which only
        surfaced later as a confusing ValueError inside Config; we now
        log and re-raise so the failure points at the real cause.
        FileNotFoundError: if configs.yaml is missing next to this file.
    """
    current_dir = os.path.dirname(__file__)
    config_path = os.path.join(current_dir, "configs.yaml")
    with open(config_path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
| |
|
class Config:
    """Flat configuration object built from a (possibly nested) config dict.

    Nested dictionaries are flattened so every leaf key becomes a direct
    attribute (e.g. config['optim']['lr'] -> self.lr).

    NOTE(review): the original decorated this class with @dataclass while
    also defining __init__ by hand. With no declared fields, the generated
    __eq__ made ALL Config instances compare equal; the decorator added
    nothing else useful, so it has been removed.
    """

    def __init__(self, config):
        """Build the config from a dict.

        Raises:
            ValueError: if `config` is None or empty.
        """
        if not config:
            raise ValueError("Configuration dictionary cannot be None or empty.")

        self._flatten_config(config)

        # YAML can deliver scientific-notation numbers (e.g. "1e-3") as
        # strings; coerce the known numeric fields back to float.
        for numeric_field in ('lr', 'weight_decay', 'stopping_thresh'):
            value = getattr(self, numeric_field, None)
            if isinstance(value, str):
                setattr(self, numeric_field, float(value))

        # Vocabulary defaults to the modulus p (inputs are residues mod p).
        if getattr(self, 'd_vocab', None) is None:
            self.d_vocab = self.p

        # Model width: one-hot embeddings require d_model == d_vocab;
        # otherwise fall back to a conventional default width.
        if getattr(self, 'd_model', None) is None:
            if getattr(self, 'embed_type', None) == 'one_hot':
                self.d_model = self.d_vocab
            else:
                self.d_model = 128

        if hasattr(self, 'seed'):
            set_all_seeds(self.seed)
            print(f"All random seeds set to: {self.seed}")

        # Lazily-built answer table for the 'rand' task (see random_answers).
        self._random_answers = None

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

    def _flatten_config(self, config_dict, parent_key=''):
        """Flatten nested configuration dictionary into flat attributes."""
        for key, value in config_dict.items():
            if isinstance(value, dict):
                self._flatten_config(value, parent_key)
            else:
                setattr(self, key, value)

    @property
    def random_answers(self):
        """p x p table of random labels for the 'rand' task.

        BUG FIX: the original regenerated a fresh table on EVERY property
        access, so the 'rand' function produced inconsistent labels
        between calls. The table is now built once and cached.
        """
        if getattr(self, '_random_answers', None) is None:
            self._random_answers = np.random.randint(
                low=0, high=self.p, size=(self.p, self.p))
        return self._random_answers

    @property
    def fns_dict(self):
        """Map of task name -> binary function over integers mod p."""
        return {
            'add': lambda x, y: (x + y) % self.p,
            'subtract': lambda x, y: (x - y) % self.p,
            'x2xyy2': lambda x, y: (x**2 + x * y + y**2) % self.p,
            'rand': lambda x, y: self.random_answers[x][y]
        }

    @property
    def fn(self):
        """The binary operation selected by self.fn_name."""
        return self.fns_dict[self.fn_name]

    def is_train_is_test(self, train):
        '''Creates an array of Boolean indices according to whether each data point is in train or test.
        Used to index into the big batch of all possible data'''
        is_train = []
        is_test = []
        for x in range(self.p):
            for y in range(self.p):
                # BUG FIX: the membership key previously hard-coded 113
                # (the usual grokking prime); use self.p so the split is
                # correct for any modulus.
                if (x, y, self.p) in train:
                    is_train.append(True)
                    is_test.append(False)
                else:
                    is_train.append(False)
                    is_test.append(True)
        return np.array(is_train), np.array(is_test)

    def is_it_time_to_save(self, epoch):
        """True every `save_every` epochs (including epoch 0)."""
        return (epoch % self.save_every == 0)

    def is_it_time_to_take_metrics(self, epoch):
        """True every `take_metrics_every_n_epochs` epochs."""
        return epoch % self.take_metrics_every_n_epochs == 0

    def update_param(self, param_name, value):
        """Set or overwrite a single configuration attribute."""
        setattr(self, param_name, value)
| |
|
| | |
def gen_train_test(config: Config):
    '''Generate train and test split with precomputed labels as tensors.

    Returns ((train_x, train_y), (test_x, test_y)) where the x tensors are
    (n, 2) long tensors of input pairs and the y tensors hold the
    corresponding labels config.fn(i, j). When frac_train == 1 the full
    dataset is returned as both splits.
    '''
    num_to_generate = config.p

    # Enumerate every (i, j) pair over Z_p x Z_p together with its label.
    all_pairs = []
    all_labels = []
    for i in range(num_to_generate):
        for j in range(num_to_generate):
            all_pairs.append((i, j))
            all_labels.append(config.fn(i, j))

    device = config.device if hasattr(config, 'device') else torch.device('cpu')
    data_tensor = torch.tensor(all_pairs, device=device, dtype=torch.long)
    labels_tensor = torch.tensor(all_labels, device=device, dtype=torch.long)

    # BUG FIX: the original called random.seed(config.seed) but then
    # shuffled with torch.randperm, so the split was never controlled by
    # config.seed (Python's RNG and torch's RNG are independent). Use a
    # dedicated, explicitly seeded torch generator so the permutation is
    # reproducible regardless of global RNG state.
    generator = torch.Generator()
    generator.manual_seed(int(config.seed))
    indices = torch.randperm(len(all_pairs), generator=generator).to(device)

    data_tensor = data_tensor[indices]
    labels_tensor = labels_tensor[indices]

    # frac_train == 1: train on everything and reuse it as the "test" set.
    if config.frac_train == 1:
        return (data_tensor, labels_tensor), (data_tensor, labels_tensor)

    div = int(config.frac_train * len(all_pairs))
    train_data = (data_tensor[:div], labels_tensor[:div])
    test_data = (data_tensor[div:], labels_tensor[div:])

    return train_data, test_data
| |
|
| | |
| | |
| |
|
| | |
def cross_entropy_high_precision(logits, labels):
    """Cross-entropy loss computed in float32.

    Casting to float32 before log_softmax avoids the precision loss that
    occurs when the model produces lower-precision logits; returns the
    mean negative log-probability of the correct class.
    """
    log_probs = F.log_softmax(logits.to(torch.float32), dim=-1)
    correct_log_probs = log_probs.gather(-1, labels.unsqueeze(-1))
    return -correct_log_probs.mean()
| |
|
def full_loss(config : Config, model: EmbedMLP, data):
    '''Takes the cross entropy loss of the model on the data.

    `data` is either a precomputed (inputs, labels) tensor pair, or a raw
    collection of (i, j) input pairs whose labels are recomputed on the
    fly via config.fn.
    '''
    if isinstance(data, tuple) and len(data) == 2:
        inputs, labels = data
    else:
        # Legacy path: bare input pairs with no labels attached.
        if not isinstance(data, torch.Tensor):
            inputs = torch.tensor(data, device=config.device)
        else:
            inputs = data if data.device == config.device else data.to(config.device)

        labels = torch.tensor([config.fn(i, j) for i, j in inputs]).to(config.device)

    logits = model(inputs)
    return cross_entropy_high_precision(logits, labels)
| |
|
def acc_rate(logits, labels):
    """Fraction of rows whose argmax over `logits` equals `labels`."""
    n_correct = (torch.argmax(logits, dim=1) == labels).sum().item()
    return n_correct / labels.size(0)
| |
|
def acc(config: Config, model: EmbedMLP, data):
    '''Compute accuracy of the model on the data.

    `data` is either a precomputed (inputs, labels) tensor pair, or a raw
    collection of (i, j) input pairs whose labels are recomputed on the
    fly via config.fn. Returns the fraction of correct argmax predictions.
    '''
    if isinstance(data, tuple) and len(data) == 2:
        data_tensor, labels = data
    else:
        # Legacy path: bare input pairs with no labels attached.
        if not isinstance(data, torch.Tensor):
            data_tensor = torch.tensor(data, device=config.device)
        elif data.device != config.device:
            data_tensor = data.to(config.device)
        else:
            data_tensor = data

        labels = torch.tensor([config.fn(i, j) for i, j in data_tensor]).to(config.device)

    logits = model(data_tensor)
    # CONSISTENCY FIX: the final argmax/compare/divide lines duplicated
    # acc_rate verbatim; delegate to it instead.
    return acc_rate(logits, labels)
| |
|
| |
|