|
import torch |
|
import os |
|
import numpy as np |
|
|
|
import lightning as pl |
|
from torch.utils.data import TensorDataset, DataLoader |
|
from adv import ValuationModel |
|
from opendataval.dataloader import DataFetcher |
|
|
|
# Short alias for NumPy arrays, used in the function annotations of this file.
Array = np.ndarray
|
|
|
def load_model(dataset_name: str) -> ValuationModel:
    """Load the best checkpoint stored for ``dataset_name``.

    The file is read from ``adv/<dataset_name>/best_checkpoint.pt`` (a path
    relative to the current working directory) and the deserialized object
    is returned unchanged.

    Args:
        dataset_name: Name of the sub-directory of ``adv`` holding the
            checkpoint.

    Returns:
        The object stored in the checkpoint file; callers in this file use
        it directly as the model passed to ``Trainer.predict``.
    """
    checkpoint_path = os.path.join("adv", dataset_name, "best_checkpoint.pt")
    # The checkpoint is a whole pickled model object, not a plain state dict.
    # PyTorch >= 2.6 defaults to ``weights_only=True``, which refuses to
    # unpickle arbitrary classes, so full unpickling is requested explicitly.
    # SECURITY: unpickling can execute arbitrary code -- only load checkpoint
    # files that come from a trusted source.
    return torch.load(checkpoint_path, weights_only=False)
|
|
|
def onehot_to_location(y: Array) -> Array:
    """Return, for every row of ``y``, the index of its first positive entry.

    Args:
        y: Two-dimensional array with one row per sample (presumably one-hot
            encoded labels -- confirm against the caller). An empty input,
            including a one-dimensional empty array, is accepted as well.

    Returns:
        Integer array of length ``y.shape[0]``. A row that contains no
        positive entry maps to ``0``.
    """
    y = np.asarray(y)
    # No rows, or rows without columns: there is nothing to scan and argmax
    # would raise on the missing / zero-length axis, so return zeros directly.
    if y.size == 0:
        return np.zeros(y.shape[0], dtype=int)
    # argmax over a boolean mask yields the position of the first True in
    # each row, and 0 for a row that is entirely False.
    return np.argmax(y > 0, axis=1).astype(int, copy=False)
|
|
|
def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
    """Fetch a dataset and return its train/validation split as NumPy arrays.

    Args:
        dataset_name: Dataset identifier handed to ``DataFetcher``.
        train_size: Number of points requested for the training split.
        valid_size: Number of points requested for the validation split.

    Returns:
        ``(x_train, y_train, x_val, y_val)``. The feature arrays are rebuilt
        from the fetcher's data via ``tolist()``; the label arrays are
        additionally passed through ``onehot_to_location`` (so the fetcher
        presumably yields one-hot labels -- confirm).
    """
    fetcher = DataFetcher(dataset_name=dataset_name)
    fetcher = fetcher.split_dataset_by_count(train_size, valid_size)

    # The fetcher exposes six collections; the last two are not used here.
    (
        train_features,
        train_onehot,
        valid_features,
        valid_onehot,
        _unused_x,
        _unused_y,
    ) = fetcher.datapoints

    return (
        np.array(train_features.tolist()),
        onehot_to_location(np.array(train_onehot.tolist())),
        np.array(valid_features.tolist()),
        onehot_to_location(np.array(valid_onehot.tolist())),
    )
|
|
|
# Script section: estimate the value of a single training point with the
# saved model and print the result.

# Dataset to evaluate, size of the training split to fetch, and the row of
# that split whose estimate is printed.
DATASET_NAME = "electricity"
TRAIN_SIZE = 100
SAMPLE_INDEX = 24

checkpoint = load_model(DATASET_NAME)
# Only the training split is needed; no validation points are requested.
x_train, y_train, *_ = load_dataset(DATASET_NAME, TRAIN_SIZE, 0)

# Build a batch holding exactly one sample. ``torch.tensor`` with an explicit
# dtype replaces the legacy ``torch.Tensor([...])`` constructor: it avoids the
# slow list-of-ndarrays path and keeps the label an integer throughout
# instead of round-tripping it through float32.
tensor_x = torch.tensor(x_train[SAMPLE_INDEX], dtype=torch.float32).unsqueeze(0)
tensor_y = torch.tensor([int(y_train[SAMPLE_INDEX])], dtype=torch.long)

dataset = TensorDataset(tensor_x, tensor_y)
dataloader = DataLoader(dataset, shuffle=False, drop_last=False, num_workers=0)

# Run prediction in bf16 mixed precision on GPU 0, without a progress bar.
trainer = pl.Trainer(
    precision="bf16-mixed",
    accelerator="gpu",
    devices=[0],
    enable_progress_bar=False,
)

# Concatenate the per-batch outputs, drop singleton dimensions, then move to
# the CPU and cast to float32 so the result can be converted to NumPy
# (NumPy has no bfloat16 dtype).
predictions = trainer.predict(checkpoint, dataloader)
estimates = torch.cat(predictions).squeeze().cpu().float()

print(estimates.numpy())