import io
import string

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
# Initialize FastAPI app
app = FastAPI()
# Load your custom model (replace this with your actual model loading code)
# Example: loading a full PyTorch model that was saved with torch.save(model, ...)
model = torch.load("path_to_your_custom_model.pth", map_location="cpu")
model.eval()  # Set the model to evaluation mode
# Preprocessing function (adjust according to your model's input requirements)
transform = transforms.Compose([
    transforms.Resize((32, 128)),  # Resize image to match your model's input size
    transforms.Grayscale(num_output_channels=1),  # If your model takes grayscale images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Example normalization, adjust accordingly
])
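# A quick sanity check for the preprocessing pipeline (a minimal sketch; the
# expected (1, 1, 32, 128) shape follows from the Resize/Grayscale settings
# above and should be adjusted alongside them):
_dummy = Image.new("RGB", (200, 50))
assert transform(_dummy).unsqueeze(0).shape == (1, 1, 32, 128)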
# Endpoint to upload an image for OCR processing
@app.post("/ocr")  # The route path is illustrative; rename it to fit your API
async def upload_file(file: UploadFile = File(...)):
    try:
        # Read the image from the uploaded file
        image_data = await file.read()
        # Convert to RGB so RGBA/palette uploads don't break the Grayscale transform
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Preprocess the image for the model
        image = transform(image).unsqueeze(0)  # Add batch dimension
        # Perform OCR using the custom model
        with torch.no_grad():  # Turn off gradients during inference
            output = model(image)  # Get model predictions
        # Decode the predicted tensor into text
        predicted_text = decode_output(output)  # Replace this with your actual decoding function
        # Return the OCR result as a JSON response
        return JSONResponse(content={"extracted_text": predicted_text})
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
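# To exercise the endpoint locally (a hedged example; it assumes the "/ocr"
# route above, the port used in the __main__ block below, and a sample image
# at test.png):
#
#   curl -X POST -F "file=@test.png" http://localhost:7860/ocr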
def decode_output(output):
    # Assuming the model outputs raw logits with shape (batch, seq_len, num_classes)
    output = F.log_softmax(output, dim=2)
    output = output.squeeze(0)  # Remove batch dimension -> (seq_len, num_classes)
    # Get the most likely character index at each timestep
    _, predicted_indices = torch.max(output, dim=1)
    # Convert indices to characters (modify the alphabet as per your model's character set)
    alphabet = string.ascii_lowercase + string.digits + " "
    predicted_text = ''.join(alphabet[i] for i in predicted_indices.tolist())
    return predicted_text
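# If the model was trained with CTC (common for CRNN-style OCR models), the raw
# per-timestep argmax above will contain repeats and blank tokens. A greedy CTC
# decoder is sketched below as an alternative; it assumes blank is class index
# 0 and that character classes start at index 1 -- adjust to your label mapping.
def ctc_greedy_decode(output, blank=0):
    # output: (batch, seq_len, num_classes) logits
    indices = output.squeeze(0).argmax(dim=1).tolist()
    alphabet = string.ascii_lowercase + string.digits + " "
    chars = []
    prev = blank
    for idx in indices:
        # Collapse repeated predictions, then drop blanks
        if idx != blank and idx != prev:
            chars.append(alphabet[idx - 1])  # Offset by 1 because index 0 is the blank
        prev = idx
    return "".join(chars)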
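# Run the app directly (a minimal sketch; Hugging Face Spaces conventionally
# serves on port 7860 -- use whatever port your deployment expects):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)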