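"""FastAPI app that serves Urdu OCR with a VisionEncoderDecoder model.

A Swin image processor and an mBART tokenizer are paired in a small
TrOCR-style processor, and the musadac/vilanocr-single-urdu checkpoint
generates text from uploaded images.
"""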
import warnings

import torch
from PIL import Image
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import (
    MBartTokenizer,
    ProcessorMixin,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)
class CustomOCRProcessor(ProcessorMixin):
    """Pairs an image processor with a tokenizer, mirroring TrOCRProcessor."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")
            image_processor = image_processor if image_processor is not None else feature_extractor

        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False
    def __call__(self, *args, **kwargs):
        # For backward compatibility: inside a target context manager,
        # forward everything to the currently active sub-processor.
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        if len(args) > 0:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            # Training case: attach the tokenized text as decoder labels.
            inputs["labels"] = encodings["input_ids"]
            return inputs
    def batch_decode(self, *args, **kwargs):
        # Decoding is delegated entirely to the tokenizer.
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)
# Build the OCR processor: a Swin-compatible image processor plus an mBART tokenizer.
image_processor = ViTImageProcessor.from_pretrained(
    'microsoft/swin-base-patch4-window12-384-in22k'
)
tokenizer = MBartTokenizer.from_pretrained(
    'facebook/mbart-large-50'
)
processortext2 = CustomOCRProcessor(image_processor, tokenizer)
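# Minimal usage sketch of the processor on its own ("page.png" is a
# hypothetical file, not part of this app):
#   img = Image.open("page.png").convert("RGB")
#   pixel_values = processortext2(img, return_tensors="pt").pixel_values  # shape (1, 3, H, W)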
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Download and load the fine-tuned Urdu OCR model on the available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model2 = VisionEncoderDecoderModel.from_pretrained(
    "musadac/vilanocr-single-urdu", use_auth_token=True
).to(device)
model2.eval()
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/upload/")
async def upload_image(image: UploadFile = File(...)):
    # Preprocess the uploaded image into model-ready pixel values.
    img = Image.open(image.file).convert("RGB")
    pixel_values = processortext2(img, return_tensors="pt").pixel_values

    # Run autoregressive generation on the encoded image.
    with torch.no_grad():
        generated_ids = model2.generate(pixel_values.to(device))

    # Decode the generated token ids into the OCR text.
    result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return {"result": result}