musadac committed on
Commit
cbbb801
1 Parent(s): e056e18

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from PIL import Image
4
+ from huggingface_hub import hf_hub_download
5
+ from transformers import VisionEncoderDecoderModel
6
+ from fastapi import FastAPI, File, UploadFile
7
+ from fastapi.responses import HTMLResponse
8
+ from fastapi.staticfiles import StaticFiles
9
+ from fastapi.templating import Jinja2Templates
10
+
11
+
12
+ import warnings
13
+ from contextlib import contextmanager
14
+ from transformers import MBartTokenizer, ViTImageProcessor, XLMRobertaTokenizer
15
+ from transformers import ProcessorMixin
16
+
17
+
18
class CustomOCRProcessor(ProcessorMixin):
    """Pair an image processor with a tokenizer for OCR.

    Mirrors the `transformers` TrOCR-style processor contract:
    calling with `images` returns pixel values, calling with `text`
    returns token encodings, and calling with both attaches the token
    ids to the image inputs under the `labels` key.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        # BUG FIX: `feature_extractor` must be initialized up front.
        # Previously it was only assigned inside the `if` branch, so
        # constructing the processor without the deprecated kwarg raised
        # NameError on the fallback line below instead of falling back
        # cleanly to None.
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        # Fall back to the deprecated kwarg when no image_processor given.
        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """Process images and/or text.

        Returns image inputs, text encodings, or (when both are given)
        image inputs with `labels` set to the text token ids.

        Raises:
            ValueError: if neither `images` nor `text` is supplied.
        """
        # For backward compatibility: inside an as_target context the
        # delegated processor handles the call directly.
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        # First positional argument is treated as the images input.
        if len(args) > 0:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            # Both given: token ids become training labels for the images.
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """Delegate batched id-to-text decoding to the tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate single-sequence decoding to the tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)
74
+
75
+
76
# BUG FIX: `device` was used below but never defined (NameError at import
# time). Select GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Vision encoder preprocessing and text decoder tokenizer.
image_processor = ViTImageProcessor.from_pretrained(
    'microsoft/swin-base-patch4-window12-384-in22k'
)
tokenizer = MBartTokenizer.from_pretrained(
    'facebook/mbart-large-50'
)
processortext2 = CustomOCRProcessor(image_processor, tokenizer)


app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Download and load the model
# NOTE(review): use_auth_token requires a logged-in HF token — confirm the
# deployment environment provides one.
model2 = VisionEncoderDecoderModel.from_pretrained("musadac/vilanocr-single-urdu", use_auth_token=True).to(device)
91
+
92
+
93
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the upload landing page."""
    # NOTE(review): Jinja2Templates normally expects a real Request object
    # in the context; passing None reproduces the original behavior —
    # confirm template does not call url_for.
    context = {"request": None}
    return templates.TemplateResponse("index.html", context)
96
+
97
@app.post("/upload/", response_class=HTMLResponse)
async def upload_image(image: UploadFile = File(...)):
    """Run OCR on an uploaded image and return the decoded text.

    BUG FIXES vs. original:
    - `processortext` and `img_tensor` were undefined names (NameError on
      every request); the module defines `processortext2` and this function
      builds `pixel_values`.
    - the input tensor is moved to the model's device so inference works
      when the model sits on GPU.
    - dropped the redundant second `.convert("RGB")`.
    """
    # Preprocess image
    img = Image.open(image.file).convert("RGB")
    pixel_values = processortext2(img, return_tensors="pt").pixel_values.to(device)

    # Run the model
    with torch.no_grad():
        generated_ids = model2.generate(pixel_values)

    # Extract OCR result
    result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return {"result": result}