import io

import cv2
import numpy as np
import pytesseract
from fastapi import FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from mltu.configs import BaseModelConfigs
from mltu.inferenceModel import OnnxInferenceModel
from mltu.transformers import ImageResizer
from mltu.utils.text_utils import ctc_decoder
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Grammarly's CoEdIT model handles the text-editing prompts sent to /chat_prompt/.
tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large")
chatModel = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large")

# Configuration (model path, vocabulary, input shape) for the handwriting model.
configs = BaseModelConfigs.load("./configs.yaml")

# Alternative grammar-correction model, currently disabled:
# from happytransformer import HappyTextToText, TTSettings
# happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
# beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)

app = FastAPI()

# Allow any origin so a browser front end can call the API during development.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ImageToWordModel(OnnxInferenceModel):
    """ONNX handwriting-recognition model with CTC decoding."""

    def __init__(self, char_list, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray) -> str:
        # Resize to the model's expected input size, preserving aspect ratio.
        image = ImageResizer.resize_maintaining_aspect_ratio(
            image, *self.input_shape[:2][::-1]
        )
        # Add a batch dimension and run ONNX inference.
        image_pred = np.expand_dims(image, axis=0).astype(np.float32)
        preds = self.model.run(None, {self.input_name: image_pred})[0]
        # Collapse the per-timestep character probabilities into a string.
        text = ctc_decoder(preds, self.char_list)[0]
        return text


model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)

# Most recent OCR result, shared with /chat_prompt/. Note: a module-level
# global is shared across all requests, so concurrent clients will overwrite
# each other's text.
extracted_text = ""


@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
    global extracted_text
    # Decode the uploaded bytes into an OpenCV BGR image.
    img_bytes = await image.read()
    nparr = np.frombuffer(img_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        return {"error": "Could not decode the uploaded file as an image."}
    # Run the handwriting-recognition model.
    extracted_text = model.predict(img)
    # corrected_text = happy_tt.generate_text(extracted_text, beam_settings)
    return {"text": extracted_text}


@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
    global extracted_text
    # Accept only uploads that declare an image content type.
    if image.content_type and image.content_type.startswith("image/"):
        image_bytes = await image.read()
        img = Image.open(io.BytesIO(image_bytes))
        # Run Tesseract OCR on printed text.
        extracted_text = pytesseract.image_to_string(img)
        # corrected_text = happy_tt.generate_text(extracted_text, beam_settings)
        return {"text": extracted_text}
    return {"error": "Invalid file format. Please upload an image."}


class ChatPrompt(BaseModel):
    prompt: str


@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    # CoEdIT expects an instruction followed by the text to edit,
    # e.g. "Fix the grammar: <text>".
    input_text = request.prompt + ": " + extracted_text
    print(input_text)  # debug: log the full prompt sent to the model
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = chatModel.generate(input_ids, max_length=256)
    edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"edited_text": edited_text}
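
# ---------------------------------------------------------------------------
# Example usage (a sketch, not part of the project): assuming this file is
# saved as main.py, the server can be started with uvicorn and exercised with
# curl. The module name and sample image paths are assumptions; the endpoints
# and JSON shapes come from the routes defined above.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST -F "image=@handwriting.png" \
#        http://localhost:8000/extract_handwritten_text/
#   curl -X POST -F "image=@scan.png" \
#        http://localhost:8000/extract_text/
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"prompt": "Fix the grammar"}' \
#        http://localhost:8000/chat_prompt/
# ---------------------------------------------------------------------------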