File size: 3,701 Bytes
4f2b696
cbbb801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c72c7b
36debf9
2c72c7b
a4172ab
 
 
 
2c72c7b
4f2b696
2c72c7b
4f2b696
 
a4172ab
4f2b696
5ead04d
cbbb801
 
a4172ab
cbbb801
5ead04d
4f2b696
 
cbbb801
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import streamlit as st
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import VisionEncoderDecoderModel


import warnings
from contextlib import contextmanager
from transformers import MBartTokenizer, ViTImageProcessor, XLMRobertaTokenizer
from transformers import ProcessorMixin


class CustomOCRProcessor(ProcessorMixin):
    """Pair an image processor with a tokenizer behind a single callable.

    Images are routed to ``image_processor`` and text to ``tokenizer``.
    Decoding helpers delegate straight to the tokenizer.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Store the two sub-processors.

        Args:
            image_processor: Object used to preprocess images.
            tokenizer: Object used to encode text and decode token ids.
            **kwargs: May contain the deprecated ``feature_extractor``, which
                is used as a fallback when ``image_processor`` is ``None``.

        Raises:
            ValueError: If no image processor (or fallback feature extractor)
                is given, or if ``tokenizer`` is ``None``.
        """
        # Bind the fallback up front: without this, omitting both
        # `image_processor` and `feature_extractor` raised UnboundLocalError
        # instead of the intended ValueError below.
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """Preprocess images and/or text.

        The first positional argument, if any, is taken as ``images`` and
        overrides an ``images`` keyword. Remaining keyword arguments are
        forwarded to both the image processor and the tokenizer.

        Returns:
            The image processor output when only images are given, the
            tokenizer output when only text is given, or the image processor
            output with a ``"labels"`` entry holding the tokenizer's
            ``"input_ids"`` when both are given.

        Raises:
            ValueError: If neither ``images`` nor ``text`` is provided.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        if args:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            # Both modalities given: expose the token ids as training labels.
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)


@st.cache_resource(show_spinner=False)
def _load_ocr_processor():
    """Build the shared OCR processor once and reuse it across reruns.

    Streamlit re-executes this script on every widget interaction; caching
    keeps the image processor and tokenizer from being reloaded each time.

    Returns:
        A ``CustomOCRProcessor`` wrapping the ViT image processor and the
        mBART tokenizer.
    """
    vit_image_processor = ViTImageProcessor.from_pretrained(
        'microsoft/swin-base-patch4-window12-384-in22k'
    )
    mbart_tokenizer = MBartTokenizer.from_pretrained(
        'facebook/mbart-large-50'
    )
    return CustomOCRProcessor(vit_image_processor, mbart_tokenizer)


processortext2 = _load_ocr_processor()
# Keep the original module-level names bound to the same objects.
image_processor = processortext2.image_processor
tokenizer = processortext2.tokenizer

import os

# Access token for the model repositories; None when the variable is unset.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

# UI label -> Hugging Face repository id of the OCR checkpoint.
model = {
    'single-urdu': "musadac/vilanocr-single-urdu",
    'multi-urdu': "musadac/ViLanOCR",
    'medical': "musadac/vilanocr-multi-medical",
    'chinese': "musadac/vilanocr-single-chinese",
}


# max_entries=1 keeps only the most recently selected checkpoint in memory
# while still avoiding a reload on every script rerun.
@st.cache_resource(show_spinner=False, max_entries=1)
def _load_ocr_model(repo_id):
    """Load the encoder-decoder checkpoint for ``repo_id`` (cached)."""
    return VisionEncoderDecoderModel.from_pretrained(repo_id, use_auth_token=huggingface_token)


st.title("Image OCR with musadac/vilanocr")
# Options come from the mapping so the two cannot drift apart; dict order
# matches the original hand-written list.
model_name = st.selectbox("Choose an OCR model", list(model))
uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    model2 = _load_ocr_model(model[model_name])
    # A single RGB conversion is enough; the result is passed on as-is.
    img = Image.open(uploaded_file).convert("RGB")
    pixel_values = processortext2(img, return_tensors="pt").pixel_values

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated_ids = model2.generate(pixel_values)

    result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]
    st.write("OCR Result:")
    st.write(result)