|
import json
import os
from typing import Tuple

import requests
import torch
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm

from IndicPhotoOCR.utils.strhub.models.utils import load_from_checkpoint


# Language-keyed registry of recognition checkpoints: "path" is relative to
# IndicPhotoOCR/recognition/, and "url" is the release asset downloaded by
# ensure_model() on first use.
model_info = {
"assamese": { |
|
"path": "models/assamese.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/assamese.ckpt", |
|
}, |
|
"bengali": { |
|
"path": "models/bengali.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/bengali.ckpt", |
|
}, |
|
"hindi": { |
|
"path": "models/hindi.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/hindi.ckpt", |
|
}, |
|
"gujarati": { |
|
"path": "models/gujarati.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/gujarati.ckpt", |
|
}, |
|
"marathi": { |
|
"path": "models/marathi.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/marathi.ckpt", |
|
}, |
|
"odia": { |
|
"path": "models/odia.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/odia.ckpt", |
|
}, |
|
"punjabi": { |
|
"path": "models/punjabi.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/punjabi.ckpt", |
|
}, |
|
"tamil": { |
|
"path": "models/tamil.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/tamil.ckpt", |
|
}, |
|
"telugu": { |
|
"path": "models/telugu.ckpt", |
|
"url" : "https://github.com/anikde/STocr/releases/download/V2.0.0/telugu.ckpt", |
|
} |
|
} |
|
|
|
class PARseqrecogniser:
    """Scene text recogniser built on PARseq checkpoints for Indic languages."""

    def __init__(self):
        pass

    def get_transform(self, img_size: Tuple[int, int], augment: bool = False, rotation: int = 0):
        """Build the image preprocessing pipeline expected by PARseq."""
        transforms = []
        if augment:
            from .augment import rand_augment_transform
            transforms.append(rand_augment_transform())
        if rotation:
            transforms.append(lambda img: img.rotate(rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        return T.Compose(transforms)
|
    def load_model(self, device, checkpoint):
        """Load a fine-tuned PARseq checkpoint in eval mode on the given device."""
        model = load_from_checkpoint(checkpoint).eval().to(device)
        return model

    def get_model_output(self, device, model, image_path):
        """Recognise the text in a single cropped word image."""
        hp = model.hparams
        transform = self.get_transform(hp.img_size, rotation=0)

        img = Image.open(image_path).convert('RGB')
        img = transform(img)
        logits = model(img.unsqueeze(0).to(device))
        probs = logits.softmax(-1)
        preds, probs = model.tokenizer.decode(probs)
        text = model.charset_adapter(preds[0])

        return text
|
    def ensure_model(self, model_name):
        """Return the local checkpoint path for `model_name`, downloading it if missing."""
        model_path = model_info[model_name]["path"]
        url = model_info[model_name]["url"]
        root_model_dir = "IndicPhotoOCR/recognition/"
        model_path = os.path.join(root_model_dir, model_path)

        if not os.path.exists(model_path):
            print(f"Model not found locally. Downloading {model_name} from {url}...")

            response = requests.get(url, stream=True)
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            os.makedirs(os.path.dirname(model_path), exist_ok=True)

            # Stream the checkpoint to disk with a progress bar.
            with open(model_path, "wb") as f, tqdm(
                desc=model_name,
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    f.write(data)
                    bar.update(len(data))

            print(f"Downloaded model for {model_name}.")

        return model_path
|
    def bstr(self, checkpoint, language, image_dir, save_dir):
        """
        Runs the OCR model over a directory of images and saves the output as a JSON file.

        Args:
            checkpoint (str): Path to the model checkpoint file.
            language (str): Language code (e.g., 'hindi', 'english').
            image_dir (str): Directory containing the images to process.
            save_dir (str): Directory where the output JSON file will be saved.

        Example:
            PARseqrecogniser().bstr("/path/to/checkpoint.ckpt", "hindi", "/path/to/images", "/path/to/save")
        """
        device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

        # English uses the pretrained PARseq weights from torch.hub; every other
        # language uses the supplied fine-tuned checkpoint.
        if language != "english":
            model = self.load_model(device, checkpoint)
        else:
            model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)

        parseq_dict = {}
        for image_name in tqdm(os.listdir(image_dir)):
            image_path = os.path.join(image_dir, image_name)
            assert os.path.exists(image_path), f"{image_path}"
            parseq_dict[image_name] = self.get_model_output(device, model, image_path)

        os.makedirs(save_dir, exist_ok=True)
        with open(f"{save_dir}/{language}_test.json", 'w') as json_file:
            json.dump(parseq_dict, json_file, indent=4, ensure_ascii=False)
|
    def bstr_onImage(self, checkpoint, language, image_path):
        """
        Runs the OCR model on a single image and returns the recognized text.

        Args:
            checkpoint (str): Path to the model checkpoint file.
            language (str): Language code (e.g., 'hindi', 'english').
            image_path (str): Path to the image to recognize.

        Returns:
            str: The recognized text from the image.
        """
        device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

        if language != "english":
            model = self.load_model(device, checkpoint)
        else:
            model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)

        text = self.get_model_output(device, model, image_path)

        return text
|
    def recognise(self, checkpoint: str, image_path: str, language: str, verbose: bool, device: str) -> str:
        """
        Loads the desired model and returns the recognized word from the specified image.

        Args:
            checkpoint (str): Key into `model_info` (e.g., 'hindi') naming the
                checkpoint to fetch and load; ignored when `language` is 'english'.
            image_path (str): Path to the image for which text recognition is needed.
            language (str): Language code (e.g., 'hindi', 'english').
            verbose (bool): Verbosity flag passed to torch.hub.load for English.
            device (str): Device to run inference on (e.g., 'cpu' or 'cuda').

        Returns:
            str: The recognized text from the image.
        """
        if language != "english":
            # Download the checkpoint on first use, then load it.
            model_path = self.ensure_model(checkpoint)
            model = self.load_model(device, model_path)
        else:
            model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, verbose=verbose).eval().to(device)

        recognized_text = self.get_model_output(device, model, image_path)

        return recognized_text
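

# Minimal usage sketch (an assumption, not part of the module as shipped):
# the image path below is a hypothetical placeholder, and "hindi" is passed
# both as the model_info key and as the language code.
if __name__ == "__main__":
    recogniser = PARseqrecogniser()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    text = recogniser.recognise("hindi", "path/to/word_crop.png", "hindi", verbose=False, device=device)
    print(text)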
|
|
|
|