import streamlit as st
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import VisionEncoderDecoderModel
import warnings
from contextlib import contextmanager
from transformers import MBartTokenizer, ViTImageProcessor, XLMRobertaTokenizer
from transformers import ProcessorMixin


class CustomOCRProcessor(ProcessorMixin):
    """Bundle an image processor and a tokenizer into one OCR processor.

    Calling the processor with images runs ``image_processor``. Calling it with
    text runs ``tokenizer``. Calling it with both returns the image inputs with
    the tokenized ``input_ids`` attached under ``"labels"``. ``decode`` and
    ``batch_decode`` delegate to the tokenizer.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Store the image processor and tokenizer.

        Args:
            image_processor: Object used to turn images into model inputs.
            tokenizer: Object used to encode text and decode generated ids.
            **kwargs: May contain the deprecated ``feature_extractor``, which is
                used as the image processor when ``image_processor`` is None.

        Raises:
            ValueError: If no image processor or no tokenizer is available.
        """
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        # Nothing visible here sets this to True; it is kept for the
        # backward-compatible branch in __call__.
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """Preprocess images and/or text.

        The first positional argument, if any, is treated as ``images``.
        Remaining positional arguments are forwarded to the image processor.

        Returns:
            The image-processor output, the tokenizer output, or (when both
            inputs are given) the image-processor output with ``"labels"`` set
            to the tokenized ``input_ids``.

        Raises:
            ValueError: If neither ``images`` nor ``text`` is provided.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        if len(args) > 0:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """Forward to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to ``tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)


# Image preprocessing from a Swin checkpoint; text side uses the mBART-50 tokenizer.
image_processor = ViTImageProcessor.from_pretrained(
    'microsoft/swin-base-patch4-window12-384-in22k'
)
tokenizer = MBartTokenizer.from_pretrained(
    'facebook/mbart-large-50'
)
processortext2 = CustomOCRProcessor(image_processor, tokenizer)

# The encoder-decoder model must exist before `generate` is called below.
# NOTE(review): checkpoint id taken from the page title — confirm it is the
# intended Hugging Face repo for this processor/tokenizer pair.
model2 = VisionEncoderDecoderModel.from_pretrained("musadac/vilanocr")

st.title("Image OCR with musadac/vilanocr")
uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Convert once; the processor receives an RGB image.
    img = Image.open(uploaded_file).convert("RGB")
    pixel_values = processortext2(img, return_tensors="pt").pixel_values

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated_ids = model2.generate(pixel_values)

    result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]
    st.write("OCR Result:")
    st.write(result)