# realfake/realfake/bin/inference.py
# devforfu
# Init
# ea847ad
import random
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from realfake.data import DictDataset, get_augs
from realfake.models import RealFakeClassifier, RealFakeParams
from realfake.utils import Args, inject_args, read_jsonl
class InferenceParams(Args):
    """Settings consumed by :func:`main` when evaluating a trained model."""

    # Checkpoint file handed to ``torch.load``; ``main`` also reads a
    # ``params.json`` file from the same directory.
    checkpoint_path: Path
    # File with the evaluation records, read via ``read_jsonl``.
    test_file: Path
    # Forwarded to ``torch.load`` as its ``map_location`` argument.
    map_location: str = "cpu"
    # Forwarded to the evaluation ``DataLoader`` as ``num_workers``.
    num_workers: int = 16
@inject_args
def main(params: InferenceParams) -> None:
    """Print the classifier's accuracy on several random test subsets.

    Restores the model weights from ``params.checkpoint_path`` (model
    hyper-parameters come from ``params.json`` in the same directory), then
    runs a fixed number of evaluation rounds. Each round draws a random
    sample of the records in ``params.test_file`` and prints the share of
    samples whose predicted class equals the labelled class.

    Args:
        params: Inference settings supplied through ``@inject_args``
            (presumably parsed from the command line -- confirm in
            ``realfake.utils``).

    Raises:
        ValueError: If ``params.test_file`` yields no records.
    """
    num_rounds = 10
    max_sample_size = 1000
    batch_size = 32

    # NOTE(review): depending on its ``weights_only`` setting, torch.load can
    # unpickle arbitrary objects -- only load checkpoints from trusted sources.
    checkpoint = torch.load(params.checkpoint_path, map_location=params.map_location)
    # todo: use PL mechanism to store hparams
    hparams = RealFakeParams.parse_file(params.checkpoint_path.parent / "params.json")
    model = RealFakeClassifier(hparams)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    records = read_jsonl(params.test_file)
    if len(records) == 0:
        raise ValueError(f"No records found in {params.test_file}")
    # random.sample raises when asked for more items than are available, so
    # small test files are evaluated in full instead of crashing.
    sample_size = min(max_sample_size, len(records))

    for _ in range(num_rounds):
        selected = random.sample(records, k=sample_size)
        ds = DictDataset(selected, get_augs(train=False))
        dl = DataLoader(ds, batch_size=batch_size, num_workers=params.num_workers, shuffle=False)
        matched, total = 0, len(ds)
        with torch.inference_mode():
            for batch in dl:
                # The model returns a 3-tuple; the first element is unused here.
                _, logits, y_true_onehot = model(batch)
                y_true = y_true_onehot.argmax(dim=1)
                # Softmax is monotonic, so the arg-max of the raw logits is
                # already the predicted class.
                y_pred = logits.argmax(dim=1)
                matched += (y_true == y_pred).sum().item()
        print(f"Accuracy: {matched/total:2.2%}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()