# Uploaded by Percy98 in commit 94f77f5 ("add models and scripts", verified).
import torch
import os
import numpy as np
import lightning as pl
from torch.utils.data import TensorDataset, DataLoader
from adv import ValuationModel
from opendataval.dataloader import DataFetcher
# Shorthand for numpy.ndarray, used in this module's type annotations.
from numpy import ndarray as Array
def load_model(dataset_name: str, *, root: str = "adv") -> "ValuationModel":
    """Load the saved best checkpoint object for ``dataset_name``.

    Reads ``<root>/<dataset_name>/best_checkpoint.pt`` with ``torch.load`` and
    returns whatever object was pickled there. It is expected to be a
    ``ValuationModel``; this is assumed, not checked.

    Args:
        dataset_name: Sub-directory of ``root`` that holds the checkpoint.
        root: Base checkpoint directory. Defaults to ``"adv"``, resolved
            relative to the current working directory.

    Returns:
        The unpickled checkpoint object.
    """
    path = os.path.join(root, dataset_name, "best_checkpoint.pt")
    # The checkpoint appears to be a whole pickled model (it is used directly as
    # the model for inference), not a state_dict. PyTorch >= 2.6 defaults to
    # weights_only=True, which refuses such objects, so opt out explicitly.
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code -- only load checkpoints from a trusted source.
    return torch.load(path, weights_only=False)
def onehot_to_location(y: "Array") -> "Array":
    """Return, for each row of ``y``, the column index of its first positive entry.

    Typically used to turn one-hot encoded labels into integer class indices.

    Args:
        y: 2-D array-like with one row per sample.

    Returns:
        1-D integer array of length ``y.shape[0]``. Rows that contain no
        positive entry (including the case of zero columns) map to 0.
    """
    positive = np.asarray(y) > 0
    if positive.shape[1] == 0:
        # argmax over an empty axis raises, but every row here has no positive
        # entry and therefore maps to the default index 0.
        return np.zeros(positive.shape[0], dtype=int)
    # argmax on a boolean array returns the FIRST True position per row, and
    # 0 for all-False rows -- exactly the "first positive, else 0" rule.
    return positive.argmax(axis=1).astype(int, copy=False)
def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
    """Fetch ``dataset_name`` and return its train and validation splits as NumPy arrays.

    The data is split by count with ``DataFetcher.split_dataset_by_count``.
    ``datapoints`` must unpack into exactly six entries. The last two
    (presumably the test split -- confirm) are discarded. Each label set is
    reduced with ``onehot_to_location`` to the index of the first positive
    entry per row.

    Returns:
        ``(x_train, y_train, x_val, y_val)``.
    """

    def to_numpy(values):
        # Round-trip through Python lists so anything exposing ``tolist()``
        # becomes a plain NumPy array.
        return np.array(values.tolist())

    source = DataFetcher(dataset_name=dataset_name)
    split = source.split_dataset_by_count(train_size, valid_size)
    train_x, train_y, val_x, val_y, _, _ = split.datapoints
    # Tuple elements are evaluated left to right: x_train, y_train, x_val, y_val.
    return (
        to_numpy(train_x),
        onehot_to_location(to_numpy(train_y)),
        to_numpy(val_x),
        onehot_to_location(to_numpy(val_y)),
    )
# Score one training point of the dataset with the saved valuation model and
# print the estimate.
# NOTE(review): this runs at import time; consider moving it under an
# `if __name__ == "__main__":` guard if the module is ever imported.
DATASET_NAME = "electricity"
TRAIN_SIZE = 100
SAMPLE_INDEX = 24  # which training point to score

checkpoint = load_model(DATASET_NAME)
x_train, y_train, *_ = load_dataset(DATASET_NAME, TRAIN_SIZE, 0)

# Index with a one-element list: this keeps the leading batch dimension and
# still raises IndexError when SAMPLE_INDEX is out of range. Building tensors
# straight from the NumPy arrays avoids torch.Tensor's slow list-of-ndarrays
# path, and the integer label no longer makes a detour through float32.
tensor_x = torch.tensor(x_train[[SAMPLE_INDEX]], dtype=torch.float32)
tensor_y = torch.tensor(y_train[[SAMPLE_INDEX]], dtype=torch.long)

dataset = TensorDataset(tensor_x, tensor_y)
dataloader = DataLoader(dataset, shuffle=False, drop_last=False, num_workers=0)
trainer = pl.Trainer(
    precision="bf16-mixed", accelerator="gpu", devices=[0], enable_progress_bar=False
)
# trainer.predict yields a list of per-batch outputs; concatenate them and drop
# singleton dims. Outputs may be bfloat16 under "bf16-mixed", which NumPy cannot
# represent, so move to CPU and upcast to float32 before .numpy().
estimates = (
    torch.cat(trainer.predict(checkpoint, dataloader))
    .squeeze()
    .cpu()
    .float()
)
print(estimates.numpy())