import os
import json
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import snapshot_download
import gradio as gr


# -------- Model Definition --------
class ChineseClassifier(nn.Module):
    """ResNet-50 encoder followed by a projection head and a linear classifier.

    Pipeline: ResNet-50 (without its final fc layer) -> flatten ->
    Linear(2048-ish -> embed_dim) -> BatchNorm1d -> Dropout(0.3) ->
    Linear(embed_dim -> num_classes).

    Args:
        embed_dim: Size of the intermediate embedding.
        num_classes: Number of output logits.
        pretrainedEncoder: If True, initialize the backbone with
            ``models.ResNet50_Weights.DEFAULT`` (torchvision downloads them if
            not cached); otherwise the backbone is randomly initialized.
        unfreezeEncoder: Value assigned to ``requires_grad`` on every backbone
            parameter; False freezes the encoder.
    """

    def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
        super().__init__()
        resnet = (
            models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            if pretrainedEncoder
            else models.resnet50()
        )
        # Keep every ResNet stage except the final fc layer.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        for param in self.resnet.parameters():
            param.requires_grad = unfreezeEncoder

        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_embedding=False):
        """Return class logits for ``x``.

        Args:
            x: Image batch, presumably (N, 3, H, W) since ResNet-50's first
                convolution takes 3 input channels.
            return_embedding: If True, return the post-dropout embedding
                (N, embed_dim) instead of the logits.
        """
        x = self.resnet(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.batch_norm(x)
        x = self.dropout(x)
        if return_embedding:
            return x
        return self.classifier(x)


# -------- Utility Functions --------
def get_sorted_classes(labels_dict):
    """Return the sorted list of unique label values in ``labels_dict``.

    The position of a class in this list is used as its class index.
    """
    return sorted(set(labels_dict.values()))


def load_labels_json(labels_json_path):
    """Load the labels JSON file and normalize its keys.

    Returns:
        dict mapping normalized path keys to their label values.
    """
    with open(labels_json_path, "r", encoding="utf-8") as f:
        labels_dict = json.load(f)
    # NOTE(review): basename runs *before* the backslash replacement, so on
    # POSIX a Windows-style key such as "dir\\img.png" keeps its prefix (as
    # "dir/img.png"). Deliberately not reordered: keys that collide after
    # stripping would be merged and could drop label values, changing the
    # class list (and therefore class indices) derived from this dict.
    return {os.path.basename(k).replace("\\", "/"): v for k, v in labels_dict.items()}


def prepare_transforms():
    """Return the inference preprocessing pipeline.

    Resizes to 224x224, converts to a tensor, and normalizes with 3-channel
    mean/std values, so the input image must be RGB.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfreeze=True):
    """Build a ChineseClassifier, load weights from ``model_path``, set eval mode.

    The checkpoint is either a dict holding ``"model_state_dict"`` or a bare
    state dict. For the wrapped format, if a strict load raises RuntimeError
    (e.g. the saved classifier has a different number of classes), all
    weights except ``classifier.*`` are loaded non-strictly and the
    classifier keeps its random initialization.

    Args:
        model_path: Path to the checkpoint file.
        embed_dim: Embedding size passed to ChineseClassifier.
        num_classes: Number of output classes.
        device: torch.device the model is moved to / weights are mapped to.
        pretrained: Forwarded as ``pretrainedEncoder``.
        unfreeze: Forwarded as ``unfreezeEncoder``.

    Returns:
        The model in eval mode.
    """
    model = ChineseClassifier(
        embed_dim, num_classes, pretrainedEncoder=pretrained, unfreezeEncoder=unfreeze
    ).to(device)
    # SECURITY NOTE(review): torch.load unpickles the file and this checkpoint
    # is fetched from a remote Hub repo. Consider weights_only=True once it is
    # confirmed the checkpoint holds only tensors and plain containers.
    checkpoint = torch.load(model_path, map_location=device)
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            print("Warning:", e)
            print("Loading partial weights, skipping classifier layer...")
            filtered_state_dict = {
                k: v for k, v in state_dict.items() if not k.startswith("classifier.")
            }
            result = model.load_state_dict(filtered_state_dict, strict=False)
            # strict=False silently tolerates key mismatches; surface them so a
            # half-loaded encoder does not go unnoticed. classifier.* keys are
            # expected to be missing because they were filtered out above.
            missing = [k for k in result.missing_keys if not k.startswith("classifier.")]
            if missing:
                print("Warning: missing keys left at initialization:", missing)
            if result.unexpected_keys:
                print("Warning: unexpected keys ignored:", result.unexpected_keys)
            print("Warning: classifier layer was not loaded and is randomly initialized.")
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    return model


# -------- Setup --------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512
LABELS_JSON_PATH = "labels.json"

# 1. Load labels and extract sorted classes; class index i maps to classes[i].
labels_dict = load_labels_json(LABELS_JSON_PATH)
classes = get_sorted_classes(labels_dict)
idx_to_class = dict(enumerate(classes))
num_classes = len(classes)

# Log the class count so it can be compared with the count used in training.
print(f"Loaded {num_classes} classes")
print(f"First 5 classes: {classes[:5]}")

# 2. Download model
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
print("Downloading model from repo...")
repo_dir = snapshot_download(repo_id=REPO_ID)
model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
print(f"Model path: {model_path}")

# 3. Load model. pretrained=False because the checkpoint supplies the encoder
# weights; downloading the torchvision pretrained weights first is wasted work.
model = load_model(model_path, EMBED_DIM, num_classes, DEVICE, pretrained=False)
transform = prepare_transforms()


# -------- Prediction Function --------
def predict(pil_img):
    """Predict the character shown in a PIL image.

    Args:
        pil_img: PIL image from the Gradio upload component, or None.

    Returns:
        The predicted class label.

    Raises:
        gr.Error: if no image was provided.
    """
    if pil_img is None:
        raise gr.Error("Please upload an image of a handwritten Chinese character.")
    # Uploads may be grayscale ("L") or carry alpha ("RGBA"); the 3-channel
    # Normalize in `transform` fails on those, so force RGB first.
    # NOTE(review): RGBA->RGB drops alpha without compositing; if uploads use
    # transparent backgrounds, compositing onto the training background
    # colour may be needed — confirm against the training data.
    img_t = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(img_t)
    pred_idx = output.argmax(dim=1).item()
    return idx_to_class[pred_idx]


# -------- Gradio Interface --------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Handwritten Chinese Character"),
    outputs=gr.Text(label="Predicted Character"),
    title="Chinese Character Recognition",
    description="Recognizes handwritten Chinese characters with 80% accuracy",
)
demo.launch()