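"""Gradio demo: handwritten Chinese character recognition.

Loads a fine-tuned ResNet-50 classifier from the Hugging Face Hub
(JJJHHHH/CCR_EthicalSplit_Finetune) and serves single-image predictions.
"""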
import os
import json
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import snapshot_download
import gradio as gr

# -------- Model Definition --------
class ChineseClassifier(nn.Module):
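    """ResNet-50 encoder followed by a projection head (linear -> batch norm -> dropout) and a linear classifier."""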
    def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) if pretrainedEncoder else models.resnet50()
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        for param in self.resnet.parameters():
            param.requires_grad = unfreezeEncoder
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_embedding=False):
        x = self.resnet(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.batch_norm(x)
        x = self.dropout(x)
        if return_embedding:
            return x
        x = self.classifier(x)
        return x

# -------- Utility Functions --------
def get_sorted_classes(labels_dict):
    """Extract sorted unique classes from labels dictionary"""
    return sorted(set(labels_dict.values()))

def load_labels_json(labels_json_path):
    """Load and normalize labels JSON"""
    with open(labels_json_path, "r", encoding="utf-8") as f:
        labels_dict = json.load(f)
    # Normalize path separators first, then strip directory prefixes so keys are bare filenames
    return {os.path.basename(k.replace("\\", "/")): v for k, v in labels_dict.items()}

def prepare_transforms():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
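        # ImageNet mean/std, matching the statistics the pretrained ResNet-50 encoder expects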
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfreeze=True):
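    """Build the classifier and load checkpoint weights; if the classifier head's
    shape does not match, fall back to a partial load that skips that layer."""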
    model = ChineseClassifier(embed_dim, num_classes, pretrainedEncoder=pretrained, unfreezeEncoder=unfreeze).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    if "model_state_dict" in checkpoint:
        try:
            model.load_state_dict(checkpoint["model_state_dict"])
        except RuntimeError as e:
            print("Warning:", e)
            print("Loading partial weights, skipping classifier layer...")
            filtered_state_dict = {k: v for k, v in checkpoint["model_state_dict"].items() if not k.startswith("classifier.")}
            model.load_state_dict(filtered_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    return model

# -------- Setup --------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512
LABELS_JSON_PATH = "labels.json"

# 1. Load labels and extract sorted classes
labels_dict = load_labels_json(LABELS_JSON_PATH)
classes = get_sorted_classes(labels_dict)
idx_to_class = {idx: c for idx, c in enumerate(classes)}
num_classes = len(classes)

# Verify class count matches training
print(f"Loaded {num_classes} classes")
print(f"First 5 classes: {classes[:5]}")

# 2. Download model
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
print("Downloading model from repo...")
repo_dir = snapshot_download(repo_id=REPO_ID)
model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
print(f"Model path: {model_path}")

# 3. Load model
model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
transform = prepare_transforms()

# -------- Prediction Function --------
def predict(pil_img):
    """Predict the character class for a PIL image."""
    if pil_img is None:
        return "Please upload an image."
    # Ensure a 3-channel input before applying the transform
    img_t = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(img_t)
        pred_idx = output.argmax(dim=1).item()
        pred_label = idx_to_class[pred_idx]
    return pred_label

# -------- Gradio Interface --------
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Handwritten Chinese Character"),
    outputs=gr.Text(label="Predicted Character"),
    title="Chinese Character Recognition",
    description="Recognizes handwritten Chinese characters with 80% accuracy",
).launch()