import streamlit as st
from streamlit_cropper import st_cropper
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor, NougatProcessor
import torch
import re
import pytesseract
from io import BytesIO
import openai
import requests
from nougat.dataset.rasterize import rasterize_paper
import uuid
import os


def get_pdf(pdf_link):
    """Download a PDF from a URL and return the local file path."""
    unique_filename = f"{os.getcwd()}/downloaded_paper_{uuid.uuid4().hex}.pdf"
    response = requests.get(pdf_link)
    if response.status_code == 200:
        with open(unique_filename, 'wb') as pdf_file:
            pdf_file.write(response.content)
        print("PDF downloaded successfully.")
    else:
        # Fail loudly instead of returning a path to a file that was never written.
        raise RuntimeError(f"Failed to download the PDF (status {response.status_code}).")
    return unique_filename


def predict_arabic(img, model_name="UBC-NLP/Qalam"):
    """Transcribe Arabic text from a PIL image with the Qalam TrOCR model."""
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    image = img.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_length=256)
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True)[0]
    return generated_text


def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2"):
    """Transcribe a PIL image with a Donut model."""
    processor = DonutProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Donut checkpoints are steered by a task start token; the CORD-v2
    # checkpoint expects "<s_cord-v2>". An empty prompt gives the decoder
    # no input ids to start generating from.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    image = img.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence).strip()  # drop remaining task tokens
    return sequence


def predict_nougat(img, model_name="facebook/nougat-small"):
    """Transcribe a page image to Markdown-style text with Nougat."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = NougatProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    model.to(device)  # without this, CUDA inputs would be fed to a CPU model

    image = img.convert("RGB")
    pixel_values = processor(
        image, return_tensors="pt", data_format="channels_first").pixel_values

    # Generate the transcription (up to 1500 new tokens per page).
    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=1500,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    page_sequence = processor.batch_decode(
        outputs, skip_special_tokens=True)[0]
    # page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
    return page_sequence
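# Note: the predict_* functions above reload their processor and model on
# every call. A minimal caching sketch, assuming a recent Streamlit with
# st.cache_resource (load_qalam is a hypothetical helper, not part of this
# app):
#
# @st.cache_resource
# def load_qalam(model_name="UBC-NLP/Qalam"):
#     return (TrOCRProcessor.from_pretrained(model_name),
#             VisionEncoderDecoderModel.from_pretrained(model_name))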
def inference_nougat(pdf_file, pdf_link):
    """Run Nougat over every page of a PDF given as an upload or a URL."""
    if pdf_file is None:
        if pdf_link == '':
            print("No file is uploaded and no link is provided.")
            return "No data provided. Upload a PDF file or provide a PDF link and try again!"
        file_name = get_pdf(pdf_link)
    else:
        # Streamlit's UploadedFile is an in-memory buffer, not a path on
        # disk, so write it out where rasterize_paper can open it.
        file_name = f"{os.getcwd()}/uploaded_{uuid.uuid4().hex}.pdf"
        with open(file_name, 'wb') as f:
            f.write(pdf_file.getvalue())

    images = rasterize_paper(file_name, return_pil=True)

    # Infer for every page and concatenate the transcriptions.
    sequence = ""
    for image in images:
        sequence += predict_nougat(image)

    # Convert LaTeX math delimiters to the $ / $$ forms st.write renders.
    content = sequence.replace(r'\(', '$').replace(
        r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
    return content


def predict_tesseract(img):
    """Transcribe a PIL image with Tesseract (callers already pass PIL images)."""
    text = pytesseract.image_to_string(img)
    return text


st.set_page_config(
    page_title="Qalam: A Multilingual OCR System",
    page_icon="🖊️",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# Qalam: A Multilingual OCR System"
    }
)

st.header("Qalam: A Multilingual OCR System")
st.sidebar.header("Configuration and Image Upload")
st.sidebar.subheader("Adjust Image Enhancement Options")
img_file = st.sidebar.file_uploader(
    label='Upload a file', type=['png', 'jpg', 'pdf'])
# input_file = st.sidebar.text_input("Enter the file URL")
realtime_update = st.sidebar.checkbox(label="Update in Real Time", value=True)
# box_color = st.sidebar.color_picker(label="Box Color", value='#0000FF')
aspect_choice = st.sidebar.radio(label="Aspect Ratio", options=["Free"])
aspect_dict = {
    "Free": None
}
aspect_ratio = aspect_dict[aspect_choice]

st.sidebar.subheader("Select OCR Language and Model")
Lng = st.sidebar.selectbox(label="Language", options=[
    "Arabic", "English", "French", "Korean", "Chinese"])

# Model used for each language.
Models = {
    "Arabic": "Qalam",
    "English": "Nougat",
    "French": "Tesseract",
    "Korean": "Donut",
    "Chinese": "Donut"
}
st.sidebar.markdown(f"### Selected Model: {Models[Lng]}")

if img_file:
    if img_file.type != "application/pdf":
        img = Image.open(img_file)
        if not realtime_update:
            st.write("Double click to save crop")

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input: Upload and Crop Your Image")
            # Get a cropped image from the frontend.
            cropped_img = st_cropper(
                img,
                realtime_update=realtime_update,
                box_color="#FF0000",
                aspect_ratio=aspect_ratio,
                should_resize_image=True,
            )
        with col2:
            st.subheader("Output: Preview and Analyze")
            st.image(cropped_img)
            button = st.button("Run OCR")
            if button:
                with st.spinner('Running OCR...'):
                    if Lng == "Arabic":
                        ocr_text = predict_arabic(cropped_img)
                    elif Lng == "English":
                        ocr_text = predict_nougat(cropped_img)
                    elif Lng == "French":
                        ocr_text = predict_tesseract(cropped_img)
                    elif Lng in ("Korean", "Chinese"):
                        ocr_text = predict_english(cropped_img)
                st.subheader(f"OCR Results for {Lng}")
                st.write(ocr_text)
                text_file = BytesIO(ocr_text.encode())
                st.download_button('Download Text', text_file,
                                   file_name='ocr_text.txt')
    else:
        button = st.sidebar.button("Run OCR")
        if button:
            with st.spinner('Running OCR...'):
                ocr_text = inference_nougat(img_file, "")
            st.subheader("OCR Results for the PDF file")
            st.write(ocr_text)
            text_file = BytesIO(ocr_text.encode())
            st.download_button('Download Text', text_file,
                               file_name='ocr_text.txt')
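# Sketch: wiring the commented-out sidebar URL field above to the Nougat PDF
# path. This is hypothetical glue, not part of the original app; it reuses
# inference_nougat's pdf_link branch.
#
# pdf_url = st.sidebar.text_input("Enter the file URL")
# if pdf_url and st.sidebar.button("Run OCR on URL"):
#     st.write(inference_nougat(None, pdf_url))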
# openai.api_key = ""

# if "openai_model" not in st.session_state:
#     st.session_state["openai_model"] = "gpt-3.5-turbo"

# if "messages" not in st.session_state:
#     st.session_state.messages = []

# for message in st.session_state.messages:
#     with st.chat_message(message["role"]):
#         st.markdown(message["content"])

# if prompt := st.chat_input("How can I help?"):
#     st.session_state.messages.append(
#         {"role": "user", "content": ocr_text + prompt})
#     with st.chat_message("user"):
#         st.markdown(prompt)

#     with st.chat_message("assistant"):
#         message_placeholder = st.empty()
#         full_response = ""
#         for response in openai.ChatCompletion.create(
#             model=st.session_state["openai_model"],
#             messages=[
#                 {"role": m["role"], "content": m["content"]}
#                 for m in st.session_state.messages
#             ],
#             stream=True,
#         ):
#             full_response += response.choices[0].delta.get("content", "")
#             message_placeholder.markdown(full_response + "▌")
#         message_placeholder.markdown(full_response)
#     st.session_state.messages.append(
#         {"role": "assistant", "content": full_response})
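# ---------------------------------------------------------------------------
# Usage sketch, assuming this file is saved as app.py and dependencies are
# installed under their usual PyPI names (version pins are not specified by
# this script):
#
#   pip install streamlit streamlit-cropper pillow transformers torch \
#       pytesseract requests nougat-ocr openai
#   streamlit run app.py
#
# pytesseract additionally requires the Tesseract binary to be on PATH;
# openai is only needed if the commented-out chat section is re-enabled.
# ---------------------------------------------------------------------------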