import os
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
import io
import torchvision.transforms as transforms
import torch.nn.functional as F
import string
# Initialize FastAPI app
app = FastAPI()
# Load your custom model (replace the path with your actual checkpoint).
# The rest of this file is PyTorch (torchvision transforms, torch.no_grad),
# so the checkpoint is loaded with torch.load rather than a Keras API.
model = torch.load("path_to_your_custom_model.pth", map_location="cpu")
model.eval()  # Set the model to evaluation mode
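# NOTE: torch.load on a full-model checkpoint requires the original model class
# to be importable here. If your checkpoint stores only a state_dict (a common
# convention), instantiate the architecture first, e.g. (MyCRNN is a
# hypothetical class name):
#   model = MyCRNN()
#   model.load_state_dict(torch.load("path_to_your_custom_model.pth", map_location="cpu"))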
# Preprocessing function (adjust according to your model's input requirements)
transform = transforms.Compose([
    transforms.Resize((32, 128)),                 # Resize to the model's expected input size (H=32, W=128)
    transforms.Grayscale(num_output_channels=1),  # If your model takes grayscale images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # Example normalization; adjust to your training setup
])
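
# Sanity check (illustrative): the pipeline above maps any PIL image to a
# (1, 32, 128) float tensor, e.g.:
#   dummy = Image.new("L", (200, 50))
#   assert transform(dummy).shape == (1, 32, 128)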
# Endpoint to upload an image for OCR processing
@app.post("/uploadfile/")
async def upload_file(file: UploadFile = File(...)):
    try:
        # Read the image from the uploaded file
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data))
        # Preprocess the image and add a batch dimension
        image = transform(image).unsqueeze(0)
        # Run inference with gradients disabled
        with torch.no_grad():
            output = model(image)
        # Decode the raw output tensor into text (assumes decode_output below
        # matches your model's output format)
        predicted_text = decode_output(output)
        # Return the OCR result as a JSON response
        return JSONResponse(content={"extracted_text": predicted_text})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
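
# Example request against a locally running server (assuming port 8000):
#   curl -X POST "http://localhost:8000/uploadfile/" -F "file=@sample.png"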
def decode_output(output):
    # Assumes the model outputs logits with shape (batch, time_steps, num_classes)
    output = F.log_softmax(output, dim=2)
    output = output.squeeze(0)  # Remove the batch dimension
    # Greedy decoding: take the most likely class at each time step.
    # Note this does not collapse repeats or drop a CTC blank token;
    # a CTC-aware variant is sketched below.
    _, predicted_indices = torch.max(output, dim=1)
    # Map indices to characters; modify the alphabet to match your model's character set
    alphabet = string.ascii_lowercase + string.digits + " "
    predicted_text = ''.join([alphabet[i] for i in predicted_indices])
    return predicted_text
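
# A minimal CTC-style greedy decoder sketch, in case the model was trained with
# CTC loss (an assumption; so is the blank token at index 0). Unlike the simple
# decoder above, CTC decoding collapses consecutive repeats and drops blanks.
def ctc_greedy_decode(output, blank_index=0):
    alphabet = string.ascii_lowercase + string.digits + " "  # Must match training
    output = F.log_softmax(output, dim=2).squeeze(0)
    _, indices = torch.max(output, dim=1)
    chars = []
    prev = blank_index
    for i in indices.tolist():
        # Skip blanks and collapse consecutive repeated indices
        if i != blank_index and i != prev:
            chars.append(alphabet[i - 1])  # Shift by one: index 0 is the blank
        prev = i
    return ''.join(chars)

# Entry point for local development; a minimal sketch assuming uvicorn is
# installed (it ships with typical FastAPI deployments).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)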