# Hugging Face Space (status when captured: Sleeping)
import os | |
import json | |
from PIL import Image | |
import torch | |
import torch.nn as nn | |
from torchvision import models, transforms | |
from huggingface_hub import snapshot_download | |
import gradio as gr | |
# -------- Model Definition -------- | |
class ChineseClassifier(nn.Module):
    """ResNet-50 encoder followed by an embedding head and a linear classifier.

    The encoder is a torchvision ResNet-50 with its last child module removed.
    Its flattened output is projected to ``embed_dim``, passed through batch
    normalisation and dropout, and finally mapped to ``num_classes`` logits.

    Args:
        embed_dim: Output size of the ``fc`` projection (the embedding size).
        num_classes: Number of logits produced by ``classifier``.
        pretrainedEncoder: If true, build the ResNet-50 with
            ``ResNet50_Weights.DEFAULT``; otherwise build it without weights.
        unfreezeEncoder: Value assigned to ``requires_grad`` on every encoder
            parameter.
    """

    def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
        super().__init__()
        if pretrainedEncoder:
            backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        else:
            backbone = models.resnet50()

        # Keep every child of the backbone except the last one (its own fc head).
        encoder_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*encoder_layers)
        for weight in self.resnet.parameters():
            weight.requires_grad = unfreezeEncoder

        # The projection consumes whatever width the removed fc head consumed.
        self.fc = nn.Linear(backbone.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_embedding=False):
        """Run the network on ``x``.

        Args:
            x: Input batch fed to the encoder.
            return_embedding: If true, return the embedding (after dropout)
                instead of the class logits.

        Returns:
            The embedding tensor when ``return_embedding`` is true, otherwise
            the output of the classifier layer.
        """
        features = self.resnet(x).flatten(1)
        embedding = self.dropout(self.batch_norm(self.fc(features)))
        return embedding if return_embedding else self.classifier(embedding)
# -------- Utility Functions -------- | |
def get_sorted_classes(labels_dict):
    """Return the distinct values of ``labels_dict`` as an ascending list."""
    unique_labels = {label for label in labels_dict.values()}
    ordered = list(unique_labels)
    ordered.sort()
    return ordered
def load_labels_json(labels_json_path):
    """Load the labels JSON and key every entry by its bare file name.

    Args:
        labels_json_path: Path to a UTF-8 JSON file containing an object that
            maps image paths to labels.

    Returns:
        A dict mapping the final path component of each key to its label.
        If two keys share the same file name, the later entry wins.
    """
    with open(labels_json_path, "r", encoding="utf-8") as f:
        labels_dict = json.load(f)
    # Convert backslashes BEFORE taking the basename: on POSIX, os.path.basename
    # does not treat "\\" as a separator, so Windows-style keys such as
    # "train\\a.png" would otherwise keep their directory prefix.
    return {
        os.path.basename(k.replace("\\", "/")): v
        for k, v in labels_dict.items()
    }
def prepare_transforms():
    """Build the image preprocessing pipeline used before inference.

    Returns:
        A ``transforms.Compose`` that resizes to 224x224, converts to a
        tensor, and normalises each of the three channels.
    """
    # Per-channel statistics (the values commonly used with ImageNet-pretrained
    # torchvision models).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfreeze=True):
    """Build a ``ChineseClassifier`` and load weights from a checkpoint file.

    The checkpoint may be either a raw state dict or a dict holding the state
    dict under ``"model_state_dict"``. In the latter case, if a strict load
    raises ``RuntimeError``, the weights are loaded again non-strictly with all
    ``classifier.*`` entries dropped, which leaves the classifier layer at its
    freshly initialised values.

    Args:
        model_path: Path to the checkpoint file.
        embed_dim: Embedding size passed to ``ChineseClassifier``.
        num_classes: Number of classes passed to ``ChineseClassifier``.
        device: Device the model is moved to and the checkpoint is mapped to.
        pretrained: Forwarded as ``pretrainedEncoder``.
        unfreeze: Forwarded as ``unfreezeEncoder``.

    Returns:
        The model in evaluation mode.
    """
    net = ChineseClassifier(
        embed_dim,
        num_classes,
        pretrainedEncoder=pretrained,
        unfreezeEncoder=unfreeze,
    )
    model = net.to(device)
    checkpoint = torch.load(model_path, map_location=device)

    if "model_state_dict" not in checkpoint:
        # The checkpoint is itself the state dict.
        model.load_state_dict(checkpoint)
    else:
        try:
            model.load_state_dict(checkpoint["model_state_dict"])
        except RuntimeError as e:
            print("Warning:", e)
            print("Loading partial weights, skipping classifier layer...")
            filtered_state_dict = {}
            for key, tensor in checkpoint["model_state_dict"].items():
                if key.startswith("classifier."):
                    continue
                filtered_state_dict[key] = tensor
            model.load_state_dict(filtered_state_dict, strict=False)

    model.eval()
    return model
# -------- Setup -------- | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
EMBED_DIM = 512
LABELS_JSON_PATH = "labels.json"

# Step 1: read the labels file and derive the class list.
# A class index is its position in the sorted list of distinct labels.
labels_dict = load_labels_json(LABELS_JSON_PATH)
classes = get_sorted_classes(labels_dict)
idx_to_class = dict(enumerate(classes))
num_classes = len(classes)

# Print the class count and a sample so they can be compared with training.
print(f"Loaded {num_classes} classes")
print(f"First 5 classes: {classes[:5]}")

# Step 2: fetch the model repository from the Hugging Face Hub.
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
print("Downloading model from repo...")
repo_dir = snapshot_download(repo_id=REPO_ID)
model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
print(f"Model path: {model_path}")

# Step 3: build the model, load the checkpoint, and create the preprocessing.
model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
transform = prepare_transforms()
# -------- Prediction Function -------- | |
def predict(pil_img):
    """Predict the character shown in a PIL image.

    Args:
        pil_img: PIL image from the Gradio image input. May be ``None``
            (for example when no image was provided).

    Returns:
        The predicted class label, or a short prompt asking for an image
        when ``pil_img`` is ``None``.
    """
    # Without this guard a missing image would crash inside the transform.
    if pil_img is None:
        return "Please upload an image."
    # The preprocessing normalises three channels, so grayscale or RGBA input
    # must be converted to RGB first.
    rgb_img = pil_img.convert("RGB")
    img_t = transform(rgb_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(img_t)
    pred_idx = output.argmax(dim=1).item()
    return idx_to_class[pred_idx]
# -------- Gradio Interface -------- | |
# Assemble the web UI: one image input, one text output, served by predict().
image_input = gr.Image(type="pil", label="Upload Handwritten Chinese Character")
text_output = gr.Text(label="Predicted Character")
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=text_output,
    title="Chinese Character Recognition",
    description="Recognizes handwritten Chinese characters with 80% accuracy",
)
demo.launch()