# DocQA / classification.py
# chandan06's picture
# Update classification.py
# ff6fbb2 verified
# raw
# history blame
# No virus
# 2.7 kB
import numpy as np
import time
from tensorflow.keras.preprocessing import image
# from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import streamlit as st
# with tf.device('/cpu:0'):
# Load the saved Keras model once, at import time. The path is relative, so it
# is resolved against the process's current working directory.
# NOTE(review): the file name suggests a ResNet-152 classifier — confirm.
model = tf.keras.models.load_model('best_resnet152_model.h5')
# Lookup table from the model's output index to its label; the position of each
# label in the tuple is the index it is stored under.
class_names = dict(enumerate(('1099_Div', '1099_Int', 'Non_Form', 'w_2', 'w_3')))
# print(class_names)
# Load and preprocess the image
# img_path = '/app/filled_form_1.jpg'
# Inference output is computed data rather than a shared resource, so it is
# cached with st.cache_data (st.cache_resource is meant for objects such as
# the model itself).
@st.cache_data
def predict(pil_img):
    """Classify an image with the module-level ``model`` and return its label.

    Args:
        pil_img: Image to classify, in any form accepted by
            ``image.img_to_array`` (a PIL image, going by the parameter name).
            NOTE(review): nothing here resizes the image or converts its
            colour mode, so the caller presumably supplies an image that
            already matches the model's expected input — confirm.

    Returns:
        The ``class_names`` entry for the highest-scoring model output.

    Raises:
        KeyError: If the winning output index has no entry in ``class_names``.
    """
    # Convert to an array and add a leading batch axis of size 1.
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    batch /= 255.0  # Rescale pixel values

    predictions = model.predict(batch)

    # argmax over axis 1 yields one index per batch row; there is only one row.
    # Cast to a plain int so the dict lookup does not depend on NumPy scalar types.
    predicted_class_index = int(np.argmax(predictions, axis=1)[0])

    predicted_class_name = class_names[predicted_class_index]
    print("Predicted class:", predicted_class_name)
    return predicted_class_name
# import numpy as np
# import time
# from PIL import Image # Import for PIL image handling
# from torchvision import transforms # Import for image preprocessing
# import torch
# import torch.nn as nn # Import for PyTorch neural networks
# import streamlit as st
# # Load the PyTorch model (assuming it's saved in PyTorch format)
# model = torch.load('./best_resnet152_model.pt') # Replace with your model filename
# # Define class names dictionary
# class_names = {0: '1099_Div', 1: '1099_Int', 2: 'Non_Form', 3: 'w_2', 4: 'w_3'}
# # Define a function for prediction using PyTorch
# @st.cache_resource
# def predict(pil_img):
# # Preprocess the image
# preprocess = transforms.Compose([
# transforms.ToTensor(), # Convert to PyTorch tensor
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize based on ImageNet statistics
# ])
# img_tensor = preprocess(pil_img)
# img_tensor.unsqueeze_(0) # Add batch dimension
# # Predict with PyTorch
# start_time = time.time()
# with torch.no_grad(): # Disable gradient calculation for prediction
# predictions = model(img_tensor)
# end_time = time.time()
# # Get the predicted class
# predicted_class_index = torch.argmax(predictions, dim=1).item()
# predicted_class_name = class_names[predicted_class_index]
# # Print results (optional for debugging)
# print("Predicted class:", predicted_class_name)
# print("Execution time: ", end_time - start_time)
# return predicted_class_name