Spaces:
Runtime error
Runtime error
import torch | |
from torch.utils.data import DataLoader | |
from torchvision.datasets import ImageFolder | |
from data import transforms | |
from model import HNet, ResNet18 | |
from tqdm import tqdm | |
import config as CFG | |
from argparse import ArgumentParser | |
def test(model_type): | |
model = None | |
if model_type == "digit": | |
test_ds = ImageFolder(root=CFG.TEST_DIGIT_PATH, transform=transforms["test"]) | |
model = HNet(num_classes=10) | |
model.load_state_dict(torch.load(CFG.BEST_MODEL_DIGIT)) | |
else: | |
test_ds = ImageFolder(root=CFG.TEST_VYANJAN_PATH, transform=transforms["test"]) | |
model = HNet(num_classes=36) | |
model.load_state_dict(torch.load(CFG.BEST_MODEL_VYANJAN)) | |
test_dl = DataLoader(test_ds, batch_size=CFG.BATCH_SIZE) | |
model.eval() | |
running_corrects = 0 | |
with torch.no_grad(): | |
for images, labels in tqdm(test_dl): | |
outputs = model(images) | |
_, preds = torch.max(outputs, 1) | |
running_corrects += torch.sum(preds == labels) | |
print( | |
f"Test Accuracy of [{model_type}] model: {round(running_corrects.item()/len(test_ds) * 100, 3)}%" | |
) | |
if __name__ == "__main__": | |
parser = ArgumentParser(description="Test model for Hindi Character Recognition") | |
parser.add_argument( | |
"--model_type", | |
type=str, | |
help="Type of model (vyanjan/digit)", | |
default="vyanjan", | |
) | |
args = parser.parse_args() | |
test(model_type=args.model_type) | |