File size: 1,914 Bytes
941e4bd 4dcc6ef 405249a 4dcc6ef cb825f4 405249a cb825f4 405249a 4dcc6ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os
from flask import Flask, request, jsonify
from PIL import Image
import torch
from torchvision.transforms import functional as F
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers import AutoModel
from transformers import AutoModel
# Checkpoint identifier handed to ``from_pretrained`` and the optional
# Hugging Face access token read from the environment (None when unset).
model_name = "trocrnew.pth"
access_token = os.environ.get("HF_TOKEN")

app = Flask(__name__)

# Load the processor and the trained model once, at import time.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
# The checkpoint used to be loaded twice (first through AutoModel, then
# through VisionEncoderDecoderModel), with the first instance discarded as
# soon as ``model`` was rebound. Load it a single time, reusing the same
# identifier and access token.
model = VisionEncoderDecoderModel.from_pretrained(model_name, token=access_token)

# Set the model in evaluation mode
model.eval()
def ocr(image):
    """Recognise the text in a single image with the module-level TrOCR model.

    Args:
        image: The image to read. The ``/ocr`` endpoint passes a PIL image
            already converted to RGB.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Preprocess with the TrOCR processor rather than a raw ``to_tensor``:
    # the processor produces the resized, normalised ``pixel_values`` batch
    # that the vision encoder was trained on.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    # Perform OCR without tracking gradients.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    # Decode the generated token ids; a single image yields a batch of one.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Handle ``POST /ocr``: run OCR on the uploaded ``file`` form field.

    Returns:
        A ``(json_response, status)`` pair:
        * ``{'text': ...}``, 200 on success;
        * ``{'error': 'No file provided'}``, 400 when the ``file`` field is missing;
        * ``{'error': 'Invalid file type'}``, 400 when the filename has no
          extension or one outside png/jpg/jpeg/gif;
        * ``{'error': <message>}``, 500 when opening the image or running
          OCR raises.
    """
    # Check if the POST request contains a file
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']

    # Check if the file has an allowed extension. Fall back to an empty
    # name so a missing filename is rejected as an invalid type instead of
    # raising outside the try block below.
    allowed_extensions = {'png', 'jpg', 'jpeg', 'gif'}
    filename = file.filename or ''
    _, dot, extension = filename.rpartition('.')
    if not dot or extension.lower() not in allowed_extensions:
        return jsonify({'error': 'Invalid file type'}), 400

    # Read the image and perform OCR
    try:
        image = Image.open(file).convert('RGB')
        text = ocr(image)
    except Exception as e:
        # Request boundary: record the traceback server-side before
        # reporting the failure, so errors are not silently swallowed.
        app.logger.exception('OCR request failed')
        return jsonify({'error': str(e)}), 500
    else:
        return jsonify({'text': text}), 200
if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger, which allows
    # code execution from the browser, so it must not be on by default.
    # Opt in for local development with FLASK_DEBUG=1.
    app.run(debug=os.environ.get('FLASK_DEBUG', '0') == '1')
|