|
|
|
import random |
|
from pathlib import Path |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
|
|
from realfake.data import DictDataset, get_augs |
|
from realfake.models import RealFakeClassifier, RealFakeParams |
|
from realfake.utils import Args, inject_args, read_jsonl |
|
|
|
|
|
class InferenceParams(Args):
    """Parameters for the checkpoint-evaluation entry point `main`."""

    # Checkpoint file passed to torch.load; the model's params.json is read
    # from this file's parent directory.
    checkpoint_path: Path
    # JSONL file of test records, loaded with read_jsonl.
    test_file: Path
    # Forwarded unchanged to torch.load(map_location=...).
    map_location: str = "cpu"
    # Number of DataLoader worker processes used during evaluation.
    num_workers: int = 16
|
|
|
|
|
# Evaluation protocol: number of independent random subsets to score, the size
# of each subset (clamped to the test-set size), and the DataLoader batch size.
_NUM_TRIALS = 10
_SAMPLE_SIZE = 1000
_BATCH_SIZE = 32


def _load_model(params: InferenceParams) -> RealFakeClassifier:
    """Build the classifier and load its weights from ``params.checkpoint_path``.

    Hyper-parameters come from ``params.json`` in the checkpoint's directory and
    weights from the checkpoint's ``"state_dict"`` entry. The model is returned
    in eval mode.
    """
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(params.checkpoint_path, map_location=params.map_location)
    model_params = RealFakeParams.parse_file(params.checkpoint_path.parent / "params.json")
    model = RealFakeClassifier(model_params)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model


def _subset_accuracy(model: RealFakeClassifier, records: list, num_workers: int) -> float:
    """Return the fraction of ``records`` whose predicted class matches the target.

    Records are wrapped in a ``DictDataset`` with the non-training augmentations
    and fed to ``model`` in batches of ``_BATCH_SIZE``.
    """
    with torch.inference_mode():
        ds = DictDataset(records, get_augs(train=False))
        dl = DataLoader(ds, batch_size=_BATCH_SIZE, num_workers=num_workers, shuffle=False)
        matched = 0
        for batch in dl:
            # The model returns (_, logits, targets); the targets are argmax'd
            # over dim 1, so presumably they are one-hot labels (confirm in
            # RealFakeClassifier.forward).
            _, logits, y_true_onehot = model(batch)
            y_true = y_true_onehot.argmax(dim=1)
            # softmax is monotonic, so argmax over raw logits gives the same class.
            y_pred = logits.argmax(dim=1)
            matched += (y_true == y_pred).sum().item()
        return matched / len(ds)


@inject_args
def main(params: InferenceParams) -> None:
    """Evaluate a checkpoint on random subsets of a JSONL test set.

    Runs ``_NUM_TRIALS`` trials. Each trial draws up to ``_SAMPLE_SIZE`` test
    records without replacement and prints the model's accuracy on them.

    Raises:
        ValueError: if ``params.test_file`` contains no records.
    """
    model = _load_model(params)

    records = read_jsonl(params.test_file)
    if not records:
        raise ValueError(f"No test records found in {params.test_file}")
    # random.sample raises ValueError when k exceeds the population, so small
    # test sets are evaluated in full instead of crashing.
    sample_size = min(_SAMPLE_SIZE, len(records))

    for _ in range(_NUM_TRIALS):
        selected = random.sample(records, k=sample_size)
        accuracy = _subset_accuracy(model, selected, params.num_workers)
        print(f"Accuracy: {accuracy:2.2%}")
|
|
|
|
|
if __name__ == "__main__":
    # `main` is called with no arguments; the @inject_args decorator is
    # expected to construct the InferenceParams instance (presumably from
    # command-line arguments -- confirm in realfake.utils).
    main()
|
|