# tr-ocr / app.py
# Last commit: cb825f4 "Update app.py" by triopood (verified); size 1.91 kB;
# Hub file scan: no virus detected.
import os
from flask import Flask, request, jsonify
from PIL import Image
import torch
from torchvision.transforms import functional as F
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers import AutoModel
from transformers import AutoModel
# Name or local path of the fine-tuned TrOCR checkpoint.
# NOTE(review): from_pretrained() expects a Hub repo id or a directory that
# holds config.json + weights; a ".pth" name suggests a raw state-dict file
# instead -- confirm what "trocrnew.pth" actually is.
model_name = "trocrnew.pth"
# Hugging Face access token forwarded to from_pretrained(); None when the
# HF_TOKEN environment variable is unset.
access_token = os.environ.get("HF_TOKEN")

app = Flask(__name__)

# Processor = image preprocessing + tokenizer used by ocr() below.
# NOTE(review): this is the base microsoft/trocr-small-printed processor;
# presumably the checkpoint was fine-tuned from it -- confirm.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")

# Load the checkpoint exactly once, directly as the encoder-decoder class
# whose generate() ocr() calls. (A separate AutoModel load of the same
# checkpoint was thrown away immediately, doubling startup time and memory.)
model = VisionEncoderDecoderModel.from_pretrained(model_name, token=access_token)
# Inference only: put modules such as dropout into evaluation mode.
model.eval()
def ocr(image):
    """Run the TrOCR model on a single image and return the recognized text.

    Args:
        image: A PIL image. The endpoint in this file passes an RGB image.

    Returns:
        The decoded text for the image, with special tokens stripped.
    """
    # Preprocess with the TrOCR processor so the image is resized and
    # normalized the way the vision encoder expects. A bare to_tensor() keeps
    # the upload's original resolution and raw [0, 1] pixel range.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Keep inputs on the same device as the model's weights.
    pixel_values = pixel_values.to(model.device)

    # Autoregressive decoding; no gradients are needed for inference.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    # batch_decode yields one string per batch item; the batch has one image.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Run OCR on an image uploaded as multipart form field ``file``.

    JSON responses:
        200 ``{'text': ...}`` with the recognized text.
        400 ``{'error': ...}`` when no file is sent, the filename lacks an
            allowed image extension, or the upload cannot be decoded as an
            image.
        500 ``{'error': ...}`` for any other failure while running OCR.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    upload = request.files['file']

    # The uploaded filename may be None or empty; treat both as "no extension"
    # instead of letting the membership test raise TypeError.
    filename = upload.filename or ''
    allowed_extensions = {'png', 'jpg', 'jpeg', 'gif'}
    _, dot, extension = filename.rpartition('.')
    if not dot or extension.lower() not in allowed_extensions:
        return jsonify({'error': 'Invalid file type'}), 400

    # Decode the upload. Pillow reports unrecognized or corrupt data as
    # OSError (UnidentifiedImageError is a subclass) and oversized images as
    # DecompressionBombError; all of these are client errors, not server faults.
    try:
        with Image.open(upload) as source:
            # convert() returns an independent copy, so the source image can
            # be released when the with-block exits.
            image = source.convert('RGB')
    except (OSError, Image.DecompressionBombError):
        return jsonify({'error': 'Invalid image file'}), 400

    try:
        text = ocr(image)
    except Exception as e:
        # Request boundary: record the full traceback server-side.
        app.logger.exception('OCR failed')
        # NOTE(review): str(e) exposes internal error details to the client;
        # consider a generic message once no client depends on this text.
        return jsonify({'error': str(e)}), 500
    return jsonify({'text': text}), 200
if __name__ == '__main__':
    # Flask debug mode serves the interactive Werkzeug debugger, which can
    # execute arbitrary code from the browser. Enable it only on explicit
    # request, e.g. FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in {'1', 'true', 'yes'}
    app.run(debug=debug_enabled)