"""Report the accuracy of a RealFakeClassifier checkpoint on random test subsets."""
import random
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from realfake.data import DictDataset, get_augs
from realfake.models import RealFakeClassifier, RealFakeParams
from realfake.utils import Args, inject_args, read_jsonl

# Number of independent random subsets that are evaluated and reported.
NUM_ROUNDS = 10
# Upper bound on the number of records drawn for each subset.
SAMPLE_SIZE = 1000
# Batch size used by the evaluation DataLoader.
BATCH_SIZE = 32


class InferenceParams(Args):
    """Parameters accepted by the inference script."""

    # Checkpoint file; a "params.json" file is read from the same directory.
    checkpoint_path: Path
    # JSONL file with the test records.
    test_file: Path
    # Passed to ``torch.load`` as ``map_location``.
    map_location: str = "cpu"
    # Number of DataLoader worker processes.
    num_workers: int = 16


@inject_args
def main(params: InferenceParams) -> None:
    """Load a checkpoint and print accuracy on several random test subsets.

    The model hyper-parameters are read from ``params.json`` located next to
    the checkpoint, and the weights from the checkpoint's ``"state_dict"``
    entry. ``NUM_ROUNDS`` times, up to ``SAMPLE_SIZE`` records are drawn
    without replacement from the test file, and the fraction of records whose
    predicted class equals the target class is printed.

    Args:
        params: Checkpoint/test-file locations and loader settings.

    Raises:
        ValueError: If the test file contains no records.
    """
    checkpoint = torch.load(params.checkpoint_path, map_location=params.map_location)

    # TODO: use PL mechanism to store hparams
    hparams_path = params.checkpoint_path.parent / "params.json"
    model = RealFakeClassifier(RealFakeParams.parse_file(hparams_path))
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    records = read_jsonl(params.test_file)
    num_records = len(records)
    if num_records == 0:
        raise ValueError(f"No records found in {params.test_file}")

    # random.sample raises ValueError when k exceeds the population size, so
    # clamp the subset size for test files smaller than SAMPLE_SIZE. In that
    # case every round evaluates the complete test set.
    sample_size = min(SAMPLE_SIZE, num_records)

    for _ in range(NUM_ROUNDS):
        selected = random.sample(records, k=sample_size)

        with torch.inference_mode():
            ds = DictDataset(selected, get_augs(train=False))
            dl = DataLoader(
                ds,
                batch_size=BATCH_SIZE,
                num_workers=params.num_workers,
                shuffle=False,
            )

            matched, total = 0, len(ds)
            for batch in dl:
                # The model returns a 3-tuple; only the logits and the targets
                # are needed here.
                _, logits, y_true_onehot = model(batch)
                y_true = y_true_onehot.argmax(dim=1)
                # Softmax is monotonic per row, so the argmax of the raw
                # logits selects the same class without the extra pass.
                y_pred = logits.argmax(dim=1)
                matched += (y_true == y_pred).sum().item()

            print(f"Accuracy: {matched/total:2.2%}")


if __name__ == "__main__":
    main()