|
import cv2 |
|
import io |
|
import numpy as np |
|
from PIL import Image |
|
|
|
import pytesseract |
|
|
|
from fastapi import FastAPI, UploadFile, File |
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
from mltu.inferenceModel import OnnxInferenceModel |
|
from mltu.utils.text_utils import ctc_decoder |
|
from mltu.transformers import ImageResizer |
|
from mltu.configs import BaseModelConfigs |
|
|
|
from textblob import TextBlob |
|
from happytransformer import HappyTextToText, TTSettings |
|
|
|
|
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
from pydantic import BaseModel |
|
|
|
# CoEdIT text-editing model used by POST /chat_prompt/. Both the tokenizer and
# the weights come from the same Hugging Face checkpoint and share a local cache.
_COEDIT_CHECKPOINT = "grammarly/coedit-large"
_HF_CACHE_DIR = "./cache"

tokenizer = AutoTokenizer.from_pretrained(_COEDIT_CHECKPOINT, cache_dir=_HF_CACHE_DIR)
chatModel = T5ForConditionalGeneration.from_pretrained(
    _COEDIT_CHECKPOINT,
    cache_dir=_HF_CACHE_DIR,
)

# mltu configuration for the handwriting model; `model_path` and `vocab` are
# read from it further down. NOTE(review): the path is relative to the current
# working directory, not to this file.
configs = BaseModelConfigs.load("./configs.yaml")

# happytransformer beam-search settings. NOTE(review): not referenced by any
# code visible here -- confirm whether it is still needed.
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)

app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True makes
# Starlette reflect the caller's Origin on credentialed requests (verify for the
# installed version), i.e. any site could call this API with the user's cookies.
# Tighten `origins` if the service ever gains authentication.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class ImageToWordModel(OnnxInferenceModel):
    """ONNX handwriting recogniser whose output is decoded with mltu's CTC decoder.

    Args:
        char_list: Vocabulary handed to ``ctc_decoder`` to turn output indices
            into characters (this file passes ``configs.vocab``).
        *args, **kwargs: Forwarded unchanged to ``OnnxInferenceModel``
            (this file passes ``model_path``).
    """

    def __init__(self, char_list, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        """Return the text recognised in a single image.

        The image is resized (keeping its aspect ratio) to the model's input
        size, given a leading batch axis, cast to float32, run through the ONNX
        session, and the first decoded string is returned.
        """
        # NOTE(review): input_shape presumably starts (height, width, ...); the
        # reversal hands (width, height) to the resizer -- confirm against mltu.
        target_size = self.input_shape[:2][::-1]
        resized = ImageResizer.resize_maintaining_aspect_ratio(image, *target_size)

        batch = np.expand_dims(resized, 0).astype(np.float32)
        outputs = self.model.run(None, {self.input_name: batch})
        decoded = ctc_decoder(outputs[0], self.char_list)
        return decoded[0]
|
|
|
|
|
# Handwriting recogniser, built once at import time from configs.yaml.
model = ImageToWordModel(char_list=configs.vocab, model_path=configs.model_path)

# Latest OCR result, written by both /extract_* endpoints and read by
# /chat_prompt/. NOTE(review): this is process-wide state shared by every
# client, so concurrent users overwrite each other's text.
extracted_text: str = ""
|
|
|
@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
    """Recognise handwritten text in an uploaded image with the ONNX CTC model.

    The result is also stored in the module-level ``extracted_text`` so that a
    later ``POST /chat_prompt/`` call can edit it.

    Args:
        image: Uploaded file; decoded by OpenCV as a colour image.

    Returns:
        ``{"text": <recognised text>}`` on success, or
        ``{"error": "Invalid file format. Please upload an image."}`` when the
        upload is empty or OpenCV cannot decode it (the same error payload that
        ``/extract_text/`` uses).
    """
    global extracted_text

    # FastAPI re-exports Starlette's helper for running blocking code in a
    # worker thread; imported locally to keep this endpoint self-contained.
    from fastapi.concurrency import run_in_threadpool

    raw = await image.read()
    # cv2.imdecode raises on an empty buffer and returns None (instead of
    # raising) for undecodable bytes; both used to surface as a 500 error.
    img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR) if raw else None
    if img is None:
        return {"error": "Invalid file format. Please upload an image."}

    # ONNX inference is synchronous CPU work; running it inline in an
    # ``async def`` would stall the event loop for every other request.
    extracted_text = await run_in_threadpool(model.predict, img)
    return {"text": extracted_text}
|
|
|
|
|
@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
    """Run Tesseract OCR on an uploaded image of printed text.

    The result is also stored in the module-level ``extracted_text`` so that a
    later ``POST /chat_prompt/`` call can edit it.

    Args:
        image: Uploaded file. Its declared content type must start with
            ``image/`` and its bytes must be decodable by Pillow.

    Returns:
        ``{"text": <OCR output>}`` on success, or
        ``{"error": "Invalid file format. Please upload an image."}`` when the
        upload is not declared as an image or cannot be decoded.
    """
    global extracted_text

    from fastapi.concurrency import run_in_threadpool

    # content_type is None when the client omits the part's Content-Type;
    # treat that as "not an image" rather than raising AttributeError.
    if not (image.content_type or "").startswith("image/"):
        return {"error": "Invalid file format. Please upload an image."}

    image_bytes = await image.read()
    try:
        img = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy: force decoding so corrupt data fails here
        # instead of deep inside pytesseract with a 500.
        img.load()
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        return {"error": "Invalid file format. Please upload an image."}

    # pytesseract shells out to the tesseract binary and blocks until it
    # exits; keep that off the event loop.
    extracted_text = await run_in_threadpool(pytesseract.image_to_string, img)
    return {"text": extracted_text}
|
|
|
class ChatPrompt(BaseModel):
    """JSON request body accepted by ``POST /chat_prompt/``."""

    # Editing instruction; the endpoint sends "<prompt>: <extracted text>" to the model.
    prompt: str
|
|
|
def _run_coedit(input_text: str) -> str:
    """Tokenize, generate (max 256 tokens) and decode CoEdIT's output; blocking."""
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = chatModel.generate(input_ids, max_length=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    """Apply an editing instruction to the most recently extracted text.

    The model input is ``"<request.prompt>: <extracted_text>"``, where
    ``extracted_text`` is whatever the last ``/extract_*`` call stored. If
    nothing has been extracted yet, the model sees only the prompt and ``": "``.

    Args:
        request: Body whose ``prompt`` is the editing instruction.

    Returns:
        ``{"edited_text": <decoded model output>}``.
    """
    import logging

    from fastapi.concurrency import run_in_threadpool

    # Only reading the module-level value, so no ``global`` declaration is needed.
    input_text = f"{request.prompt}: {extracted_text}"
    # Previously print(); user text should go to a log level that can be disabled.
    logging.getLogger(__name__).debug("chat_prompt input: %s", input_text)

    # Seq2seq generation is slow, blocking PyTorch work; inline in an
    # ``async def`` it would freeze the event loop for all other requests.
    edited_text = await run_in_threadpool(_run_coedit, input_text)
    return {"edited_text": edited_text}
|
|