# CCR_OCR / app.py
# Author: JJJHHHH
# Last commit: "Update app.py" (3b4987c, verified)
import os
import json
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import snapshot_download
import gradio as gr
# -------- Model Definition --------
class ChineseClassifier(nn.Module):
    """ResNet-50 feature extractor followed by an embedding head and a linear classifier.

    Args:
        embed_dim: width of the embedding produced by ``fc`` (and normalised
            by ``batch_norm``).
        num_classes: number of logits produced by ``classifier``.
        pretrainedEncoder: when True, build the backbone with torchvision's
            ``ResNet50_Weights.DEFAULT`` weights; otherwise randomly initialised.
        unfreezeEncoder: value written to ``requires_grad`` on every backbone
            parameter (False freezes the backbone).
    """

    def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
        super().__init__()
        if pretrainedEncoder:
            backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        else:
            backbone = models.resnet50()

        # Drop only the last child (ResNet's final ``fc``) so the backbone
        # outputs pooled features instead of ImageNet logits.
        feature_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*feature_layers)
        for weight in self.resnet.parameters():
            weight.requires_grad = unfreezeEncoder

        # Attribute names below become state_dict keys (e.g. "classifier.*"),
        # which checkpoint loading relies on; keep them and their order.
        feature_dim = backbone.fc.in_features
        self.fc = nn.Linear(feature_dim, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_embedding=False):
        """Run the network on a batch of images.

        Args:
            x: image batch accepted by the ResNet-50 backbone
                (presumably (batch, 3, H, W) — TODO confirm with callers).
            return_embedding: when True, stop after dropout and return the
                ``embed_dim``-wide embedding instead of class logits.

        Returns:
            The embedding, or logits of width ``num_classes``.
        """
        pooled = torch.flatten(self.resnet(x), 1)
        embedding = self.dropout(self.batch_norm(self.fc(pooled)))
        return embedding if return_embedding else self.classifier(embedding)
# -------- Utility Functions --------
def get_sorted_classes(labels_dict):
    """Return the distinct label values of *labels_dict* in ascending order.

    Callers map a model output index to the label at that list position, so
    this ordering is assumed to match the one used at training time (not
    checked here).
    """
    unique_labels = set(labels_dict.values())
    return sorted(unique_labels)
def load_labels_json(labels_json_path):
    """Read a UTF-8 JSON object mapping image paths to labels and shorten its keys.

    Each key is passed through ``os.path.basename`` and any backslashes left
    afterwards are turned into forward slashes. On POSIX ``basename`` splits
    only on "/", so a Windows-style key such as ``"train\\img.png"`` keeps its
    directory part (becoming ``"train/img.png"``); on Windows both separators
    are stripped.

    NOTE(review): keys that become identical after shortening collapse into a
    single entry (the value from the later key wins), which can drop labels.
    """
    with open(labels_json_path, "r", encoding="utf-8") as handle:
        raw_labels = json.load(handle)

    shortened = {}
    for path_key, label in raw_labels.items():
        short_key = os.path.basename(path_key).replace("\\", "/")
        shortened[short_key] = label
    return shortened
def prepare_transforms():
    """Build the preprocessing pipeline applied to every input image.

    Resizes to 224x224, converts to a float tensor, then normalises each of
    the three channels with the standard ImageNet mean/std used by
    torchvision's pretrained ResNet weights.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfreeze=True):
    """Build a ChineseClassifier, load checkpoint weights into it and set eval mode.

    Args:
        model_path: file readable by ``torch.load``; either a dict holding a
            ``"model_state_dict"`` entry or a bare state dict.
        embed_dim: embedding width passed to ChineseClassifier.
        num_classes: number of output classes passed to ChineseClassifier.
        device: device the model is moved to and the checkpoint is mapped to.
        pretrained: forwarded as ``pretrainedEncoder``. When True, torchvision's
            pretrained ResNet-50 weights are built first, although loaded
            checkpoint weights overwrite them.
        unfreeze: forwarded as ``unfreezeEncoder``.

    Returns:
        The model, in eval mode.

    Raises:
        RuntimeError: if a bare state dict does not match the model, or if the
            classifier-skipping fallback still leaves non-classifier weights
            missing (for example keys saved under a different prefix). Without
            this check those weights would silently stay at initialisation.
    """
    model = ChineseClassifier(
        embed_dim,
        num_classes,
        pretrainedEncoder=pretrained,
        unfreezeEncoder=unfreeze,
    ).to(device)

    # NOTE(review): torch.load unpickles the file, which can execute code;
    # only load checkpoints from a trusted source (consider weights_only=True
    # if the checkpoint contains only tensors and plain containers).
    checkpoint = torch.load(model_path, map_location=device)

    if "model_state_dict" not in checkpoint:
        # Bare state dict: strict load, errors propagate to the caller.
        model.load_state_dict(checkpoint)
        model.eval()
        return model

    state_dict = checkpoint["model_state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        print("Warning:", e)
        print("Loading partial weights, skipping classifier layer...")
        filtered_state_dict = {
            k: v for k, v in state_dict.items() if not k.startswith("classifier.")
        }
        result = model.load_state_dict(filtered_state_dict, strict=False)

        # strict=False ignores absent keys silently; only the classifier may be
        # missing here, anything else means the checkpoint did not really load.
        missing = [k for k in result.missing_keys if not k.startswith("classifier.")]
        if missing:
            raise RuntimeError(
                f"Checkpoint is missing {len(missing)} non-classifier weights "
                f"(first few: {missing[:5]})"
            ) from e
        if result.unexpected_keys:
            print("Warning: ignored unexpected checkpoint keys:", result.unexpected_keys)
        # The classifier keeps its fresh initialisation, so its outputs are
        # not trained predictions.
        print("Warning: classifier layer was not loaded from the checkpoint.")

    model.eval()
    return model
# -------- Setup --------
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# Width of the model's embedding layer; must match the checkpoint's fc/batch_norm shapes.
EMBED_DIM = 512
# Relative path, resolved against the current working directory.
LABELS_JSON_PATH = "labels.json"

# 1. Load labels; the sorted class list defines the index -> character mapping
labels_dict = load_labels_json(LABELS_JSON_PATH)
classes = get_sorted_classes(labels_dict)
num_classes = len(classes)
idx_to_class = dict(enumerate(classes))

# Print the class count and a sample so they can be compared by hand with the
# training run; nothing is checked automatically here.
print(f"Loaded {num_classes} classes")
print(f"First 5 classes: {classes[:5]}")
# 2. Download the model checkpoint from the Hub
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
MODEL_FILENAME = "CCR_EthicalSplit_Finetune.pth"
print("Downloading model from repo...")
# allow_patterns limits the snapshot to the single file this app loads,
# instead of downloading every file in the repository.
repo_dir = snapshot_download(repo_id=REPO_ID, allow_patterns=[MODEL_FILENAME])
model_path = os.path.join(repo_dir, MODEL_FILENAME)
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"{MODEL_FILENAME} not found in Hub repo {REPO_ID}")
print(f"Model path: {model_path}")

# 3. Load model weights and build the matching preprocessing pipeline
model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
transform = prepare_transforms()
# -------- Prediction Function --------
def predict(pil_img):
    """Predict the character shown in a single uploaded image.

    Args:
        pil_img: PIL image from the Gradio ``Image`` component, or ``None``
            when the user submits without uploading an image.

    Returns:
        The label string of the highest-scoring class, or a short prompt
        asking for an image when ``pil_img`` is ``None``.
    """
    if pil_img is None:
        return "Please upload an image."
    # Normalize in the preprocessing pipeline uses 3-channel statistics, so
    # grayscale ("L") or RGBA images would fail there; force RGB first.
    batch = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(batch)
    pred_idx = logits.argmax(dim=1).item()
    return idx_to_class[pred_idx]
# -------- Gradio Interface --------
# Single-image upload in, predicted character out. The accuracy quoted in the
# description is the author's stated figure; this script does not measure it.
image_input = gr.Image(type="pil", label="Upload Handwritten Chinese Character")
text_output = gr.Text(label="Predicted Character")
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=text_output,
    title="Chinese Character Recognition",
    description="Recognizes handwritten Chinese characters with 80% accuracy",
)
demo.launch()