"""Run a trained data-valuation model on a single training sample.

Loads the pickled model stored at ``adv/<dataset>/best_checkpoint.pt``,
fetches a small training split through opendataval's ``DataFetcher``, and
prints the model's ``Trainer.predict`` output for one sample.
"""
import os

import lightning as pl
import numpy as np
import torch
from opendataval.dataloader import DataFetcher
from torch.utils.data import DataLoader, TensorDataset

from adv import ValuationModel

Array = np.ndarray


def load_model(dataset_name: str) -> ValuationModel:
    """Load the model object saved at ``adv/<dataset_name>/best_checkpoint.pt``.

    Args:
        dataset_name: Sub-directory of ``adv`` that holds the checkpoint.

    Returns:
        The unpickled model object (used directly as the model passed to
        ``Trainer.predict`` by the script below).
    """
    path = os.path.join("adv", dataset_name, "best_checkpoint.pt")
    # The file contains a whole pickled module, not just a state_dict, so
    # full unpickling is required. torch>=2.6 defaults to weights_only=True,
    # which would reject it. NOTE: unpickling can execute arbitrary code --
    # only load checkpoints from trusted sources.
    return torch.load(path, weights_only=False)


def onehot_to_location(y: Array) -> Array:
    """Return, per row of ``y``, the column index of its first positive entry.

    Rows with no positive entry map to 0.

    Args:
        y: 2-D array, presumably one-hot encoded labels (one row per sample).
            An empty array (of any rank) yields an empty result.

    Returns:
        1-D integer array of length ``y.shape[0]``.
    """
    y = np.asarray(y)
    # Empty input (e.g. a zero-sized validation split, which arrives as a
    # 1-D array of shape (0,)) and zero-width rows have no column to pick;
    # np.argmax would raise on these, so return zeros explicitly.
    if y.shape[0] == 0 or y.shape[-1] == 0:
        return np.zeros(y.shape[0], dtype=int)
    # argmax over a boolean mask returns the first True index, or 0 when a
    # row has no True value.
    return np.argmax(y > 0, axis=1).astype(int)


def load_dataset(
    dataset_name: str, train_size: int, valid_size: int
) -> tuple[Array, Array, Array, Array]:
    """Fetch a dataset and split it by count into train / validation arrays.

    Args:
        dataset_name: Name understood by opendataval's ``DataFetcher``.
        train_size: Number of training points.
        valid_size: Number of validation points (may be 0).

    Returns:
        ``(x_train, y_train, x_val, y_val)`` as numpy arrays. Labels are
        converted from (presumably one-hot) rows to integer class indices
        via ``onehot_to_location``. The test split is discarded.
    """
    fetcher = DataFetcher(dataset_name=dataset_name).split_dataset_by_count(
        train_size, valid_size
    )
    x_train, y_train, x_val, y_val, _, _ = fetcher.datapoints

    # Round-trip through .tolist() to obtain plain numpy arrays regardless of
    # the container type DataFetcher hands back.
    x_train = np.array(x_train.tolist())
    y_train = onehot_to_location(np.array(y_train.tolist()))
    x_val = np.array(x_val.tolist())
    y_val = onehot_to_location(np.array(y_val.tolist()))
    return x_train, y_train, x_val, y_val


if __name__ == "__main__":
    DATASET = "electricity"
    SAMPLE_INDEX = 24

    checkpoint = load_model(DATASET)
    x_train, y_train, *_ = load_dataset(DATASET, 100, 0)

    # Slice instead of index so the tensors keep a leading batch dimension of
    # 1 (same shapes as the original list-wrapping), without going through
    # torch's slow "tensor from a list of ndarrays" path.
    sample = slice(SAMPLE_INDEX, SAMPLE_INDEX + 1)
    tensor_x = torch.from_numpy(x_train[sample]).float()
    tensor_y = torch.from_numpy(y_train[sample]).long()

    dataset = TensorDataset(tensor_x, tensor_y)
    dataloader = DataLoader(
        dataset, shuffle=False, drop_last=False, num_workers=0
    )

    trainer = pl.Trainer(
        precision="bf16-mixed",
        accelerator="gpu",
        devices=[0],
        enable_progress_bar=False,
    )
    # Predictions come back per batch; concatenate them, then cast to float32
    # because numpy has no bfloat16 dtype.
    estimates = (
        torch.cat(trainer.predict(checkpoint, dataloader))
        .squeeze()
        .cpu()
        .float()
    )
    print(estimates.numpy())