|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
from PIL import Image, ImageFont, ImageDraw |
|
import numpy as np |
|
import os |
|
import string |
|
import cv2 |
|
from torchvision.transforms.functional import to_pil_image |
|
import matplotlib.pyplot as plt |
|
import math |
|
from datetime import datetime |
|
import re |
|
from termcolor import colored |
|
from pyctcdecode import BeamSearchDecoderCTC, Alphabet |
|
from difflib import SequenceMatcher |
|
|
|
|
|
|
|
# Character vocabulary: ASCII letters, digits and punctuation. Index 0 is
# reserved for the CTC blank, so real characters are numbered from 1.
CHARS = string.ascii_letters + string.digits + string.punctuation
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
CHAR2IDX["<BLANK>"] = 0
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
BLANK_IDX = 0

# Model input size in pixels.
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Runtime state shared by the Gradio handlers.
font_path = None  # font path recorded by the UI handlers
ocr_model = None  # active OCRModel, or None until one is trained/loaded

# pyctcdecode vocabulary: one entry per model output class, in index order.
# pyctcdecode identifies the CTC blank by the empty string "", so index 0 is ""
# rather than the "<BLANK>" placeholder used in CHAR2IDX/IDX2CHAR.
# NOTE(review): pyctcdecode may remap "|" to " " when the vocabulary has no
# space character — confirm whether predicted "|" survives decoding.
labels = [""] + [IDX2CHAR[i] for i in range(1, len(IDX2CHAR))]
alphabet = Alphabet.build_alphabet(labels)
decoder = BeamSearchDecoderCTC(alphabet)

# Working directories for uploaded fonts, trained weights and label images.
os.makedirs("./fonts", exist_ok=True)
os.makedirs("./models", exist_ok=True)
os.makedirs("./labels", exist_ok=True)
|
|
|
|
|
class OCRDataset(Dataset):
    """Synthetic word images rendered on the fly from a font file.

    Each item is ``(image, label_encoded, label_length)``:

    * ``image`` - float tensor of shape ``(1, IMAGE_HEIGHT, IMAGE_WIDTH)`` in
      [-1, 1], dark text on a light background;
    * ``label_encoded`` - ``torch.long`` tensor of ``CHAR2IDX`` indices;
    * ``label_length`` - 0-d ``torch.long`` tensor holding the label length.

    Images use the same geometry as ``preprocess_image`` gives inference input:
    scaled to ``IMAGE_HEIGHT`` with the aspect ratio kept, then right-padded
    with white up to ``IMAGE_WIDTH``. Renders that would be wider than
    ``IMAGE_WIDTH`` are squeezed horizontally rather than cropped, so the image
    always shows every character of its label.
    """

    def __init__(self, font_path, size=1000, label_length_range=(4, 7)):
        self.font = ImageFont.truetype(font_path, 32)
        # (low, high) for np.random.randint; high is exclusive, so (4, 7)
        # yields labels of 4-6 characters.
        self.label_length_range = label_length_range
        self.samples = [
            "".join(np.random.choice(list(CHARS), np.random.randint(*self.label_length_range)))
            for _ in range(size)
        ]

        # Augmentations run on the white-background PIL render *before*
        # normalisation. Pixels uncovered by RandomAffine are therefore filled
        # with white (255), not normalised 0 (which is mid-grey).
        self.augment = transforms.Compose([
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
            transforms.RandomApply(
                [transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), fill=255)],
                p=0.3,
            ),
        ])
        # Final conversion: uint8 image -> float tensor normalised to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        label = self.samples[idx]

        img = self._render_label(label)
        img = self.augment(img)
        img = self._fit_to_canvas(img)
        img = self.transform(img)

        label_encoded = torch.tensor([CHAR2IDX[c] for c in label], dtype=torch.long)
        label_length = torch.tensor(len(label_encoded), dtype=torch.long)
        return img, label_encoded, label_length

    def _render_label(self, label, pad=8):
        """Draw ``label`` black-on-white with ``pad`` pixels of margin around it."""
        left, _, right, _ = self.font.getbbox(label)
        ascent, descent = self.font.getmetrics()
        # The height comes from the font's line metrics rather than the ink
        # box. Every label then gets the same text scale regardless of which
        # glyphs it contains, with room for descenders.
        img = Image.new("L", (right - left + 2 * pad, ascent + descent + 2 * pad), 255)
        ImageDraw.Draw(img).text((pad - left, pad), label, font=self.font, fill=0)
        return img

    @staticmethod
    def _fit_to_canvas(img):
        """Scale to IMAGE_HEIGHT (aspect kept, width capped at IMAGE_WIDTH), then pad right with white."""
        w, h = img.size
        new_w = min(max(1, int(w * (IMAGE_HEIGHT / h))), IMAGE_WIDTH)
        resized = img.resize((new_w, IMAGE_HEIGHT), Image.BILINEAR)
        canvas = Image.new("L", (IMAGE_WIDTH, IMAGE_HEIGHT), 255)
        canvas.paste(resized, (0, 0))
        return canvas

    def render_text(self, text):
        """Return an IMAGE_WIDTH x IMAGE_HEIGHT image with ``text`` centred in black on white."""
        img = Image.new("L", (IMAGE_WIDTH, IMAGE_HEIGHT), color=255)
        left, top, right, bottom = self.font.getbbox(text)
        # getbbox is relative to the draw origin, so subtract its top-left
        # corner to centre the ink box itself.
        x = (IMAGE_WIDTH - (right - left)) // 2 - left
        y = (IMAGE_HEIGHT - (bottom - top)) // 2 - top
        ImageDraw.Draw(img).text((x, y), text, font=self.font, fill=0)
        return img
|
|
|
|
|
|
|
class OCRModel(nn.Module):
    """Small CRNN: CNN feature extractor -> 2-layer BiLSTM -> per-timestep linear head.

    ``forward`` returns raw class scores of shape ``(batch, timesteps, num_classes)``.
    Callers apply ``log_softmax`` themselves before CTC loss or decoding.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Each pool has stride 2 along the height but stride 1 along the width.
        # The width therefore stays long and serves as the sequence/time axis.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1)),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1)),
        )
        # 64 channels x 8 rows per column, which assumes 32-pixel-high input
        # (32 / 2 / 2 = 8).
        self.rnn = nn.LSTM(64 * 8, 128, bidirectional=True, num_layers=2, batch_first=True)
        # Bidirectional: 2 x 128 hidden features per timestep.
        self.fc = nn.Linear(2 * 128, num_classes)
        # Start class 0 (the blank index used elsewhere in this file) with a
        # strongly negative bias.
        # NOTE(review): presumably meant to delay all-blank collapse early in
        # training — confirm intent.
        with torch.no_grad():
            self.fc.bias[0] = -5.0

    def forward(self, x):
        batch, _channels, _height, _width = x.shape
        feats = self.conv(x)                    # (B, 64, H/4, W-2)
        feats = feats.permute(0, 3, 1, 2)       # (B, W-2, 64, H/4): width becomes time
        seq = feats.reshape(batch, feats.size(1), -1)
        seq, _state = self.rnn(seq)
        return self.fc(seq)
|
def color_char(c, conf):
    """Wrap ``c`` in an ANSI colour escape picked by confidence ``conf`` (nominally 0-1).

    The colour is followed by a reset code. NOTE: a second ``color_char``
    (producing HTML) is defined further down this module and replaces this one
    once the module has finished importing.
    """
    palette = ("\033[31m", "\033[33m", "\033[32m", "\033[36m", "\033[34m", "\033[35m", "\033[0m")
    top = len(palette) - 1
    level = int(conf * top)
    if level > top:
        level = top
    return f"{palette[level]}{c}\033[0m"
|
|
|
def sanitize_filename(name):
    """Return ``name`` with every character outside ``[A-Za-z0-9_-]`` replaced by ``_``."""
    unsafe = re.compile(r"[^A-Za-z0-9_\-]")
    return unsafe.sub("_", name)
|
|
|
def greedy_decode(log_probs, *, verbose=True, idx2char=None, blank_idx=None):
    """Best-path (greedy) CTC decoding of a single sequence.

    Takes the arg-max class at every timestep, collapses consecutive repeats
    and drops blanks.

    Args:
        log_probs: scores for one sequence, either ``(T, C)`` or time-major
            ``(T, 1, C)``.
        verbose: print the raw arg-max indices (the previous, always-on behaviour).
        idx2char: index -> character map; defaults to the module's ``IDX2CHAR``.
            Unknown indices decode to "".
        blank_idx: CTC blank index; defaults to the module's ``BLANK_IDX``.

    Returns:
        The decoded string.
    """
    if idx2char is None:
        idx2char = IDX2CHAR
    if blank_idx is None:
        blank_idx = BLANK_IDX

    indices = log_probs.argmax(-1)
    if indices.dim() > 1:
        # (T, 1) from time-major single-batch input -> (T,)
        indices = indices.squeeze(1)
    pred = indices.tolist()
    if verbose:
        print(f"Decoded indices: {pred}")

    decoded = []
    prev = blank_idx
    for p in pred:
        # A character is emitted only when it differs from the previous
        # timestep, so a repeat needs an intervening blank.
        if p != prev and p != blank_idx:
            decoded.append(idx2char.get(p, ""))
        prev = p
    return "".join(decoded)
|
|
|
|
|
|
|
|
|
|
|
def custom_collate_fn(batch):
    """Collate ``(image, encoded_label, length)`` samples into a CTC-ready batch.

    Returns ``(images, targets, target_lengths)``:

    * ``images``: the per-sample image tensors stacked on a new leading dim;
    * ``targets``: all encoded labels concatenated into one 1-D tensor;
    * ``target_lengths``: ``torch.long`` tensor of each label's length,
      recomputed from the labels (the per-sample length entry is ignored).
    """
    images, encoded_labels, _ignored_lengths = zip(*batch)
    stacked = torch.stack(images, 0)

    lengths = [len(lbl) for lbl in encoded_labels]
    targets = torch.cat(encoded_labels)

    return stacked, targets, torch.tensor(lengths, dtype=torch.long)
|
|
|
|
|
|
|
def list_saved_models():
    """Return the names of the ``.pth`` files in ``./models`` (empty list if the directory is missing)."""
    model_dir = "./models"
    if os.path.exists(model_dir):
        return list(filter(lambda name: name.endswith(".pth"), os.listdir(model_dir)))
    return []
|
|
|
|
|
|
|
def save_model(model, path):
    """Write ``model``'s state_dict (parameters and buffers only) to ``path`` with torch.save."""
    state = model.state_dict()
    torch.save(state, path)
|
|
|
|
|
def load_model(filename):
    """Load ``./models/<filename>`` into a new OCRModel and make it the active model.

    Args:
        filename: checkpoint file name as chosen in the UI dropdown. It may be
            None or "" when nothing is selected.

    Returns:
        A status message for the UI. Common failures (nothing selected, missing
        file, unreadable or incompatible checkpoint) are reported here instead
        of raised. On failure the previously active model is kept.
    """
    global ocr_model

    if not filename:
        return "Please select a model file first."

    model_dir = "./models"
    path = os.path.join(model_dir, filename)

    if not os.path.exists(path):
        return f"Model file '{path}' does not exist."

    try:
        try:
            # The checkpoint is a plain state_dict, so refuse arbitrary pickled
            # objects.
            state_dict = torch.load(path, map_location=device, weights_only=True)
        except TypeError:
            # torch < 1.13 has no weights_only argument.
            state_dict = torch.load(path, map_location=device)
        model = OCRModel(num_classes=len(CHAR2IDX))
        model.load_state_dict(state_dict)
    except Exception as e:  # UI boundary: report instead of crashing the handler
        return f"Failed to load model '{path}': {e}"

    model.to(device)
    model.eval()
    ocr_model = model
    return f"Model '{path}' loaded."
|
|
|
|
|
|
|
def _set_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group in ``optimizer`` to ``lr``."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def train_model(font_file, epochs=100, learning_rate=0.001):
    """Train a fresh OCRModel on text synthesised from an uploaded font.

    The font is copied into ``./fonts`` and recorded in the module-level
    ``font_path``. Training uses a label-length curriculum, a 5-epoch warm-up
    at 20% of ``learning_rate``, then ReduceLROnPlateau on the epoch loss. The
    weights are saved to ``./models`` and the model becomes the active
    ``ocr_model``.

    Args:
        font_file: the Gradio upload, one of:
            - a tempfile wrapper with ``.name`` (Gradio 3);
            - a path string (Gradio 4);
            - None when nothing was uploaded.
        epochs: number of epochs. Coerced to int because UI sliders may send floats.
        learning_rate: target Adam learning rate after warm-up.

    Returns:
        A status string for the UI.
    """
    import time

    global font_path, ocr_model

    if font_file is None:
        return "❌ Please upload a .ttf or .otf font file first."

    epochs = int(epochs)
    learning_rate = float(learning_rate)
    if epochs < 1:
        return "❌ Epochs must be at least 1."

    os.makedirs("./fonts", exist_ok=True)
    os.makedirs("./models", exist_ok=True)

    # Copy the upload into ./fonts, keeping its .ttf/.otf extension.
    src_path = getattr(font_file, "name", font_file)
    font_name, font_ext = os.path.splitext(os.path.basename(src_path))
    font_ext = font_ext.lower()
    if font_ext not in (".ttf", ".otf"):
        font_ext = ".ttf"
    train_font_path = f"./fonts/{font_name}{font_ext}"
    # Read fully before writing, so copying a file onto itself cannot truncate it first.
    with open(src_path, "rb") as uploaded:
        font_bytes = uploaded.read()
    with open(train_font_path, "wb") as f:
        f.write(font_bytes)
    font_path = train_font_path

    def get_dataset_for_epoch(epoch):
        """Curriculum: short labels first, longer later (randint's upper bound is exclusive)."""
        if epoch < epochs // 3:
            label_len = (3, 4)  # always 3 characters
        elif epoch < 2 * epochs // 3:
            label_len = (4, 6)  # 4-5 characters
        else:
            label_len = (5, 7)  # 5-6 characters
        # Use the path captured above, not the global ``font_path``. Other UI
        # handlers can reassign that global while training is running.
        return OCRDataset(train_font_path, label_length_range=label_len)

    # Sanity-check one sample's label. It is printed rather than displayed with
    # plt.show(), which blocks until the figure window is closed and would
    # stall this UI handler.
    _, sample_label, _ = get_dataset_for_epoch(0)[0]
    print("Label:", "".join(IDX2CHAR[i.item()] for i in sample_label))

    model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
    criterion = nn.CTCLoss(blank=BLANK_IDX)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
    warmup_epochs = 5

    for epoch in range(epochs):
        # Warm up at 20% of the target rate, then restore the target rate ONCE.
        # Reassigning it every epoch would undo each ReduceLROnPlateau reduction.
        if epoch < warmup_epochs:
            _set_learning_rate(optimizer, learning_rate * 0.2)
        elif epoch == warmup_epochs:
            _set_learning_rate(optimizer, learning_rate)

        dataloader = DataLoader(
            get_dataset_for_epoch(epoch), batch_size=16, shuffle=True, collate_fn=custom_collate_fn
        )

        model.train()
        running_loss = 0.0
        for images, targets, target_lengths in dataloader:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            output = model(images)  # (B, T, C)
            batch_size, seq_len = output.size(0), output.size(1)
            # Every sample uses the model's full output length.
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
            # CTCLoss expects time-major (T, B, C) log-probabilities.
            log_probs = output.log_softmax(2).transpose(0, 1)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / max(len(dataloader), 1)
        # The plateau scheduler acts only after warm-up; during warm-up the LR is fixed above.
        if epoch >= warmup_epochs:
            scheduler.step(avg_loss)
        print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")

    timestamp = time.strftime("%Y%m%d%H%M%S")
    model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
    model_path = os.path.join("./models", model_name)
    save_model(model, model_path)

    model.eval()
    ocr_model = model
    return f"✅ Training complete! Model saved as '{model_path}'"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_image(image: Image.Image):
    """Turn an arbitrary word-strip image into the model's IMAGE_HEIGHT x IMAGE_WIDTH input.

    Steps:
        1. Convert to grayscale and apply an adaptive threshold.
        2. Make the text dark on a white background.
        3. Scale to IMAGE_HEIGHT, keeping the aspect ratio.
        4. Right-pad with white up to IMAGE_WIDTH.

    Inputs that would be wider than IMAGE_WIDTH after scaling are squeezed
    horizontally to fit, not cropped, so trailing characters are not cut off.

    Returns:
        A single-channel PIL image of size (IMAGE_WIDTH, IMAGE_HEIGHT).
    """
    gray = np.array(image.convert("L"))

    # Gaussian-weighted local threshold. With THRESH_BINARY_INV, pixels darker
    # than their neighbourhood (minus C=15) become 255.
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 15
    )

    # If black pixels dominate, the background is black. Invert so the result
    # is dark text on white, like the training renders.
    white_px = (binary == 255).sum()
    black_px = (binary == 0).sum()
    if black_px > white_px:
        binary = 255 - binary

    h, w = binary.shape
    scale = IMAGE_HEIGHT / h
    # Clamp to [1, IMAGE_WIDTH]. Very tall, narrow inputs would otherwise give
    # width 0, which cv2.resize rejects. Over-wide inputs would need cropping.
    new_w = min(max(1, int(w * scale)), IMAGE_WIDTH)
    resized = cv2.resize(binary, (new_w, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)

    if new_w < IMAGE_WIDTH:
        padded = np.pad(resized, ((0, 0), (0, IMAGE_WIDTH - new_w)), constant_values=255)
    else:
        padded = resized

    return to_pil_image(padded)
|
|
|
|
|
|
|
|
|
|
|
# Confidence colour scale, lowest to highest:
# red, orange, yellow, green, sky blue, blue, violet.
CONFIDENCE_COLORS = [
    "#FF0000",
    "#FF7F00",
    "#FFFF00",
    "#00FF00",
    "#00BFFF",
    "#0000FF",
    "#8B00FF",
]


def confidence_to_color(conf):
    """Map a confidence in [0, 1] to a hex colour from CONFIDENCE_COLORS (ROYGBIV, low to high).

    Out-of-range values are clamped:
        - values <= 0 and NaN give the lowest colour;
        - values >= 1 give the highest colour.

    A negative value must never reach the list index directly, because it
    would wrap around to the high end of the scale.
    """
    last = len(CONFIDENCE_COLORS) - 1
    if not conf > 0:  # also catches NaN, for which int() would raise
        return CONFIDENCE_COLORS[0]
    if conf >= 1:  # also catches +inf, for which int() would raise
        return CONFIDENCE_COLORS[last]
    return CONFIDENCE_COLORS[int(conf * last)]


def color_char(c, conf):
    """Return character ``c`` as a bold 12pt HTML span coloured according to ``conf``.

    ``c`` is HTML-escaped: the OCR vocabulary includes ``<``, ``>`` and ``&``,
    which would otherwise be parsed as markup in the HTML output.
    """
    import html

    color = confidence_to_color(conf)
    return f'<span style="color:{color}; font-size:12pt; font-weight:bold;">{html.escape(str(c))}</span>'
|
|
|
|
|
|
|
|
|
|
|
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
    """Recognise the text in a word-strip image with the active model.

    Args:
        image: input image (any PIL mode), or None when the UI has no upload.
        ground_truth: optional expected text. When given, a difflib similarity
            ratio is appended to the output.
        debug: also print the decoded text and confidence to stdout.

    Returns:
        An HTML snippet with the prediction, its average confidence and
        (optionally) the similarity. If no model or image is available, a plain
        message is returned instead.
    """
    if ocr_model is None:
        return "Please load or train a model first."
    if image is None:
        return "Please upload an image first."

    processed = preprocess_image(image)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    img_tensor = transform(processed).unsqueeze(0).to(device)  # batch of one

    ocr_model.eval()
    with torch.no_grad():
        output = ocr_model(img_tensor)
        log_probs = output.log_softmax(2)[0]  # (timesteps, num_classes) for the single image

    # Beam-search decode. Strip literal "<BLANK>" in case the decoder
    # vocabulary spells the CTC blank that way.
    pred_text = decoder.decode(log_probs.cpu().numpy()).replace("<BLANK>", "")

    # Confidence: mean over all timesteps (blank steps included) of the top
    # class probability.
    avg_conf = log_probs.exp().max(dim=1).values.mean().item()

    # A single sequence-level confidence is available, so every character
    # gets the same colour.
    pretty_output = "".join(color_char(c, avg_conf) for c in pred_text)

    sim_score = ""
    if ground_truth:
        # difflib's ratio comes from gestalt (Ratcliff/Obershelp-style)
        # matching, not Levenshtein distance, so it is labelled accordingly.
        similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
        sim_score = f"<br><strong>Similarity (difflib ratio):</strong> {similarity:.2%}"

    if debug:
        print("Decoded Text:", pred_text)
        print("Average Confidence:", avg_conf)
        if ground_truth:
            print("Ground Truth:", ground_truth)

    return f"<strong>Prediction:</strong> <strong>{pretty_output}</strong><br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}"
|
|
|
|
|
|
|
|
|
# Label-generation settings. CHARS is re-declared here with the same value as
# the vocabulary definition at the top of the module.
CHARS = "".join((string.ascii_letters, string.digits, string.punctuation))

# Font size, margin (px) and output directory for label images; these match
# the values generate_labels renders with.
FONT_SIZE = 32
PADDING = 8
LABEL_DIR = "./labels"
|
|
|
def generate_labels(font_file=None, num_labels: int = 25):
    """Render random label strings and save them as PNGs under LABEL_DIR.

    Each label is 4-6 random characters from CHARS, drawn black on white with
    PADDING pixels of margin around the ink. It is saved as
    ``LABEL_DIR/<sanitised label>/<timestamp>.png``.

    Args:
        font_file: path to a .ttf/.otf font. Pillow's built-in default font is
            used when the argument is None or "None", or when the file does
            not exist.
        num_labels: number of labels to generate. Coerced to int; None or
            negative values mean none.

    Returns:
        The list of generated PIL images. On any error, a single RGB image
        showing the error message is returned instead (UI boundary).
    """
    try:
        # Resolve the font locally. This deliberately leaves the module-level
        # ``font_path`` alone: it belongs to training, and overwriting it from
        # this handler could swap the font under a training run in progress.
        if font_file and font_file != "None":
            label_font_path = os.path.abspath(font_file)
        else:
            label_font_path = None

        if label_font_path is None or not os.path.exists(label_font_path):
            font = ImageFont.load_default()
        else:
            font = ImageFont.truetype(label_font_path, FONT_SIZE)

        os.makedirs(LABEL_DIR, exist_ok=True)
        count = max(0, int(num_labels or 0))
        labels = ["".join(np.random.choice(list(CHARS), np.random.randint(4, 7))) for _ in range(count)]
        images = []

        for label in labels:
            left, top, right, bottom = font.getbbox(label)
            img = Image.new("L", (right - left + 2 * PADDING, bottom - top + 2 * PADDING), color=255)
            # getbbox is relative to the draw origin. Offsetting by its top-left
            # corner puts the ink exactly PADDING px from the edges. Drawing at
            # (PADDING, PADDING) would shift the glyphs down by ``top`` and can
            # clip their bottom.
            ImageDraw.Draw(img).text((PADDING - left, PADDING - top), label, font=font, fill=0)

            label_dir = os.path.join(LABEL_DIR, sanitize_filename(label))
            os.makedirs(label_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
            img.save(os.path.join(label_dir, f"{timestamp}.png"))

            images.append(img)

        return images

    except Exception as e:
        print("Error in generate_labels:", e)
        error_img = Image.new("RGB", (512, 128), color=(255, 255, 255))
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
        return [error_img]
|
|
|
def list_fonts():
    """Return Dropdown choices for the fonts saved in ``./fonts``.

    The result is always a list of ``(label, value)`` pairs:
        - first ``("None", "None")``, meaning the built-in font;
        - then one ``(filename, path)`` per .ttf/.otf file (case-insensitive
          extension), sorted by filename.

    If ``./fonts`` is missing (or is not a directory), only the "None" entry is
    returned.
    """
    font_dir = "./fonts"
    choices = [("None", "None")]
    if not os.path.isdir(font_dir):
        return choices
    for name in sorted(os.listdir(font_dir)):
        if name.lower().endswith((".ttf", ".otf")):
            choices.append((name, os.path.join(font_dir, name)))
    return choices
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). Selectors target components by the
# elem_id values set in the UI: "label-gallery" (generated label images) and
# "output-text" (recognition result).
custom_css = """

#label-gallery .gallery-item img {

    height: 43px; /* 32pt ≈ 43px */

    width: auto;

    object-fit: contain;

    padding: 4px;

}



#label-gallery {

    flex-grow: 1;

    overflow-y: auto;

    height: 100%;

}

#output-text {

    font-size: 12pt;

}

"""
|
|
|
|
|
# Gradio UI with three tabs: train a model, generate label images, recognise text.
with gr.Blocks(css=custom_css) as demo:
    with gr.Tab("【Train OCR Model】"):
        font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
        epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
        # The default matches train_model's own default (1e-3, the customary
        # Adam rate). The lower bound leaves room for smaller rates.
        lr_input = gr.Slider(
            minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate"
        )
        train_button = gr.Button("Train OCR Model")
        train_status = gr.Textbox(label="Status")

        train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)

    with gr.Tab("【Generate Labels】"):
        font_file_labels = gr.Dropdown(
            choices=list_fonts(),
            label="Optional font for label image",
            interactive=True,
        )
        # list_fonts() above runs once at startup. Fonts that training saves
        # into ./fonts later only appear after a re-scan.
        refresh_fonts_btn = gr.Button("🔄 Refresh Fonts")
        num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
        gen_button = gr.Button("Generate Label Grid")
        label_gallery = gr.Gallery(
            label="Generated Labels",
            columns=16,
            object_fit="contain",
            height="100%",
            elem_id="label-gallery",
        )

        refresh_fonts_btn.click(fn=lambda: gr.update(choices=list_fonts()), outputs=font_file_labels)
        gen_button.click(fn=generate_labels, inputs=[font_file_labels, num_labels], outputs=label_gallery)

    with gr.Tab("【Recognize Text】"):
        model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
        refresh_btn = gr.Button("🔄 Refresh Models")
        load_model_btn = gr.Button("Load Model")

        image_input = gr.Image(type="pil", label="Upload word strip")
        predict_btn = gr.Button("Predict")
        output_text = gr.HTML(label="Recognized Text", elem_id="output-text")
        model_status = gr.Textbox(label="Model Load Status")

        refresh_btn.click(fn=lambda: gr.update(choices=list_saved_models()), outputs=model_list)
        load_model_btn.click(fn=load_model, inputs=model_list, outputs=model_status)
        predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)


if __name__ == "__main__":
    demo.launch()
|
|