File size: 546 Bytes
ab8b628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import unittest
import os
from image_classification_model.predict import predict
from image_classification_model.utils import ROOT_DIR

# Directory holding the image fixtures used by these tests. Pass each path
# component separately so os.path.join inserts the platform's separator,
# rather than embedding a hard-coded "/" inside one component.
DATA_DIR = os.path.join(ROOT_DIR, "tests", "data")


class TestPrediction(unittest.TestCase):
    """Checks labels returned by ``predict`` for known fixture images."""

    def test_prediction_label_3(self):
        """The fixture ``number_3.jpg`` must be predicted as label 3."""
        expected = 3
        image_file = os.path.join(DATA_DIR, "number_3.jpg")
        actual = predict(image_file)
        # unittest's default longMessage=True appends this text to the
        # standard "first != second" diff instead of replacing it.
        self.assertEqual(actual, expected, f"Expected label 3, but got {actual}")


# Allow running this file directly (python <file>.py); unittest.main()
# discovers and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()