bhuvaneshprasad commited on
Commit
a2de17e
·
verified ·
1 Parent(s): 65e13be

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +46 -46
prediction.py CHANGED
@@ -1,47 +1,47 @@
1
- import os
2
- from pathlib import Path
3
- from dotenv import load_dotenv
4
- import numpy as np
5
- import tensorflow as tf
6
-
7
- load_dotenv()
8
-
9
- class PredictionPipeline:
10
- """
11
- A class representing a pipeline for making predictions using a pre-trained model.
12
-
13
- Attributes:
14
- filename (str): The filename of the image to predict.
15
-
16
- Methods:
17
- predict() -> int:
18
- Loads a pre-trained model, processes an image, and predicts its class.
19
- """
20
- def __init__(self,filename):
21
- """
22
- Initialize the PredictionPipeline class.
23
-
24
- Args:
25
- filename (str): The filename of the image to predict.
26
- """
27
- self.filename =filename
28
-
29
- def predict(self) -> int:
30
- """
31
- Perform prediction on the image specified by the filename.
32
-
33
- Returns:
34
- int: The predicted class label.
35
- """
36
- model = tf.keras.models.load_model(os.path.join(os.getcwd(),Path(os.getenv('MODEL_URI'))))
37
- class_labels = ['brightpixel','narrowband',
38
- 'narrowbanddrd','noise',
39
- 'squarepulsednarrowband','squiggle',
40
- 'squigglesquarepulsednarrowband']
41
-
42
- imagename = self.filename
43
- test_image = tf.keras.preprocessing.image.load_img(imagename, target_size = (256,256))
44
- test_image = tf.keras.preprocessing.image.img_to_array(test_image)
45
- test_image = np.expand_dims(test_image, axis = 0)
46
- result = np.argmax(model.predict(test_image), axis=1)
47
  return class_labels[int(result)]
 
1
+ import os
2
+ from pathlib import Path
3
+ from dotenv import load_dotenv
4
+ import numpy as np
5
+ import tensorflow as tf
6
+
7
+ load_dotenv()
8
+
9
+ class PredictionPipeline:
10
+ """
11
+ A class representing a pipeline for making predictions using a pre-trained model.
12
+
13
+ Attributes:
14
+ filename (str): The filename of the image to predict.
15
+
16
+ Methods:
17
+ predict() -> int:
18
+ Loads a pre-trained model, processes an image, and predicts its class.
19
+ """
20
+ def __init__(self,filename):
21
+ """
22
+ Initialize the PredictionPipeline class.
23
+
24
+ Args:
25
+ filename (str): The filename of the image to predict.
26
+ """
27
+ self.filename =filename
28
+
29
+ def predict(self) -> int:
30
+ """
31
+ Perform prediction on the image specified by the filename.
32
+
33
+ Returns:
34
+ int: The predicted class label.
35
+ """
36
+ model = tf.keras.models.load_model("model.keras")
37
+ class_labels = ['brightpixel','narrowband',
38
+ 'narrowbanddrd','noise',
39
+ 'squarepulsednarrowband','squiggle',
40
+ 'squigglesquarepulsednarrowband']
41
+
42
+ imagename = self.filename
43
+ test_image = tf.keras.preprocessing.image.load_img(imagename, target_size = (256,256))
44
+ test_image = tf.keras.preprocessing.image.img_to_array(test_image)
45
+ test_image = np.expand_dims(test_image, axis = 0)
46
+ result = np.argmax(model.predict(test_image), axis=1)
47
  return class_labels[int(result)]