import os
from collections import Counter

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from IndicTransToolkit import IndicProcessor

from IndicPhotoOCR.ocr import OCR  # Ensure the OCR class is saved in a file named ocr.py
from IndicPhotoOCR.theme import Seafoam
from IndicPhotoOCR.utils.helper import detect_para


def Most_Common(lst):
    """Return the most frequent element in a list."""
    data = Counter(lst)
    return data.most_common(1)[0][0]
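# Example (a sketch): Most_Common(["hindi", "hindi", "english"]) returns
# "hindi"; used below to pick the dominant detected script when choosing
# the translation direction.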
# Initialize the OCR object for text detection and recognition
ocr = OCR(device="cpu", verbose=False)
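# Methods of OCR used in this app: detect(image_path) returns polygon
# detections, visualize_detection(...) draws them on the image,
# crop_and_identify_script(...) crops a region and identifies its script,
# and recognise(...) reads the text in a cropped region.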
def translate(given_str, lang="hindi"):
    """Translate `given_str` with IndicTrans2.

    If the detected script is English, translate English -> Hindi;
    otherwise translate Hindi -> English. Note: the model and tokenizer
    are reloaded on every call, which is slow; caching them would speed
    up repeated requests.
    """
    DEVICE = "cpu"
    # Pick the translation direction from the detected script
    model_name = (
        "ai4bharat/indictrans2-en-indic-1B"
        if lang == "english"
        else "ai4bharat/indictrans2-indic-en-1B"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    ip = IndicProcessor(inference=True)
    model = model.to(DEVICE)
    model.eval()

    # Non-English input is assumed to be Hindi (Devanagari) here
    src_lang, tgt_lang = (
        ("eng_Latn", "hin_Deva") if lang == "english" else ("hin_Deva", "eng_Latn")
    )
    batch = ip.preprocess_batch(
        [given_str],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    translation = ip.postprocess_batch(generated_tokens, lang=tgt_lang)[0]
    return translation
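# Minimal usage sketch for translate(); kept commented out because each call
# downloads and runs a 1B-parameter IndicTrans2 model on CPU. Under the
# assumptions above (non-English input is Hindi):
#
#     print(translate("The weather is nice today", lang="english"))  # -> Hindi
#     print(translate("आज मौसम अच्छा है", lang="hindi"))              # -> English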
def process_image(image):
    """
    Processes the uploaded image for text detection, recognition, and translation.
    - Detects bounding boxes in the image
    - Draws bounding boxes on the image and identifies the script in each detected area
    - Recognizes text in each cropped region, then translates the assembled text

    Parameters:
        image (PIL.Image): The input image to be processed.

    Returns:
        tuple: A PIL.Image with bounding boxes drawn and a string of translated text.
    """
    # Save the input image temporarily
    image_path = "input_image.jpg"
    image.save(image_path)

    # Detect bounding boxes on the image using OCR
    detections = ocr.detect(image_path)

    # Draw bounding boxes on the image and save it as output
    ocr.visualize_detection(image_path, detections, save_path="output_image.png")

    # Load the annotated image with bounding boxes drawn
    output_image = Image.open("output_image.png")

    # Map each detected region to its recognized text and bounding box
    recognized_texts = {}
    pil_image = Image.open(image_path)
    # Process each detected bounding box for script identification and text recognition
    langs = []
    for idx, bbox in enumerate(detections):
        # Identify the script and crop the image to this region
        script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)

        # Convert the detected polygon to an axis-aligned bounding box
        x1 = min(point[0] for point in bbox)
        y1 = min(point[1] for point in bbox)
        x2 = max(point[0] for point in bbox)
        y2 = max(point[1] for point in bbox)

        if script_lang:  # Only proceed if a script language is identified
            # Recognize text in the cropped area
            recognized_text = ocr.recognise(cropped_path, script_lang)
            recognized_texts[f"img_{idx}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}
            langs.append(script_lang)

    # Guard against images with no recognized text (Most_Common would fail on an empty list)
    if not langs:
        return output_image, ""

    # Group the recognized words into lines using their bounding boxes,
    # then join them into a single display string
    lines = detect_para(recognized_texts)
    recognized_texts_combined = "\n".join([" ".join(line) for line in lines])

    # Translate using the most common script among the detected words
    recognized_texts_combined = translate(recognized_texts_combined, Most_Common(langs))

    return output_image, recognized_texts_combined
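# Usage sketch for the full pipeline outside Gradio (assumes one of the
# example images listed below exists locally):
#
#     annotated, translated = process_image(Image.open("test_images/208.jpg"))
#     annotated.save("annotated.png")
#     print(translated)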
# Custom HTML for the interface header with logos and alignment
interface_html = """
<div style="text-align: left; padding: 10px;">
    <div style="background-color: white; padding: 10px; display: inline-block;">
        <img src="https://iitj.ac.in/images/logo/Design-of-New-Logo-of-IITJ-2.png" alt="IITJ Logo" style="width: 100px; height: 100px;">
    </div>
    <img src="https://play-lh.googleusercontent.com/_FXSr4xmhPfBykmNJvKvC0GIAVJmOLhFl6RA5fobCjV-8zVSypxX8yb8ka6zu6-4TEft=w240-h480-rw" alt="Bhashini Logo" style="width: 100px; height: 100px; float: right;">
</div>
"""
# Links to the GitHub and dataset repositories
links_html = """
<div style="text-align: center; padding-top: 20px;">
    <a href="https://github.com/Bhashini-IITJ/IndicPhotoOCR" target="_blank" style="margin-right: 20px; font-size: 18px; text-decoration: none;">
        GitHub Repository
    </a>
    <a href="https://github.com/Bhashini-IITJ/BharatSceneTextDataset" target="_blank" style="font-size: 18px; text-decoration: none;">
        Dataset Repository
    </a>
</div>
"""
# Custom CSS to enlarge the font in the output text box
custom_css = """
.custom-textbox textarea {
    font-size: 20px !important;
}
"""
# Create an instance of the Seafoam theme for a consistent visual style
seafoam = Seafoam()

# Define examples for users to try out
examples = [
    ["test_images/208.jpg"],
    ["test_images/1310.jpg"],
]
title = "<h1 style='text-align: center;'>Developed by IITJ</h1>"

# Set up the Gradio Interface with the defined function and customizations
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=[
        gr.Image(type="pil", label="Detected Bounding Boxes"),
        gr.Textbox(label="Translated Text", elem_classes="custom-textbox"),
    ],
    title="Scene Text Translator",
    description=title + interface_html + links_html,
    theme=seafoam,
    css=custom_css,
    examples=examples,
)
# Server setup and launch configuration; for a custom deployment, use e.g.:
#     demo.launch(server_name="0.0.0.0", server_port=7867)
demo.launch()