Spaces:
Sleeping
Sleeping
File size: 4,252 Bytes
b5c1972 05c366d faac63e b5c1972 0cb7f53 b5c1972 faac63e b5c1972 faac63e b5c1972 faac63e 0cb7f53 faac63e 3b4987c 0f763d7 3b4987c 0f763d7 3b4987c b5c1972 faac63e b5c1972 3b4987c faac63e 3b4987c faac63e 3b4987c 0f763d7 faac63e 3b4987c 0f763d7 3b4987c 0cb7f53 6012f41 585a1e2 3b4987c b5c1972 3b4987c faac63e b5c1972 0cb7f53 faac63e 3b4987c faac63e b5c1972 faac63e b5c1972 faac63e 0cb7f53 faac63e 585a1e2 6197db5 585a1e2 3b4987c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import os
import json
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import snapshot_download
import gradio as gr
# -------- Model Definition --------
class ChineseClassifier(nn.Module):
    """ResNet-50 encoder with an embedding head and a linear classifier.

    Args:
        embed_dim: Output width of the ``fc`` projection (the embedding size).
        num_classes: Number of outputs of the final ``classifier`` layer.
        pretrainedEncoder: When true, the ResNet-50 is built with
            ``ResNet50_Weights.DEFAULT``; otherwise without preset weights.
        unfreezeEncoder: Value written to ``requires_grad`` on every encoder
            parameter.
    """

    def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
        super().__init__()
        if pretrainedEncoder:
            backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        else:
            backbone = models.resnet50()

        # Everything but the network's last child module becomes the encoder.
        encoder_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*encoder_layers)
        for weight in self.resnet.parameters():
            weight.requires_grad = unfreezeEncoder

        # The new projection takes the input width of the backbone's own fc.
        feature_dim = backbone.fc.in_features
        self.fc = nn.Linear(feature_dim, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_embedding=False):
        """Run the encoder and heads on ``x``.

        Args:
            x: Input batch passed straight to the encoder.
            return_embedding: When true, stop after the dropout layer and
                return the embedding instead of the classifier output.

        Returns:
            The embedding tensor if ``return_embedding`` is true, otherwise
            the output of ``classifier`` applied to that embedding.
        """
        features = self.resnet(x)
        # Collapse every dimension after the batch dimension.
        features = torch.flatten(features, 1)
        embedding = self.dropout(self.batch_norm(self.fc(features)))
        if not return_embedding:
            return self.classifier(embedding)
        return embedding
# -------- Utility Functions --------
def get_sorted_classes(labels_dict):
    """Return the distinct values of ``labels_dict`` as an ascending list."""
    unique_labels = set(labels_dict.values())
    ordered = list(unique_labels)
    ordered.sort()
    return ordered
def load_labels_json(labels_json_path):
    """Load the labels JSON and key it by bare file name.

    Args:
        labels_json_path: Path to a UTF-8 JSON file holding a mapping from
            image path to label.

    Returns:
        A dict with the same labels, whose keys are reduced to the final path
        component. Both forward-slash and backslash separated keys are
        handled on every platform.
    """
    with open(labels_json_path, "r", encoding="utf-8") as f:
        labels_dict = json.load(f)
    # Convert backslashes BEFORE taking the basename: on POSIX,
    # os.path.basename does not treat "\" as a separator, so doing it the
    # other way round would leave Windows-style directory prefixes in place.
    return {
        os.path.basename(key.replace("\\", "/")): label
        for key, label in labels_dict.items()
    }
def prepare_transforms():
    """Build the preprocessing pipeline applied to each input image.

    Returns:
        A ``transforms.Compose`` that resizes to 224x224, converts to a
        tensor, and normalizes each of three channels.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics;
    # confirm they match what was used at training time.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfreeze=True):
    """Build a ``ChineseClassifier`` and load weights from a checkpoint file.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.
        embed_dim: Embedding width forwarded to the classifier constructor.
        num_classes: Number of classes forwarded to the classifier constructor.
        device: Device the model is moved to and the checkpoint is mapped to.
        pretrained: Forwarded as ``pretrainedEncoder``.
        unfreeze: Forwarded as ``unfreezeEncoder``.

    Returns:
        The model in eval mode.

    Raises:
        RuntimeError: If the checkpoint is a bare state dict that does not
            match the model (no fallback is attempted in that case).
    """
    net = ChineseClassifier(
        embed_dim,
        num_classes,
        pretrainedEncoder=pretrained,
        unfreezeEncoder=unfreeze,
    ).to(device)

    # NOTE(review): torch.load unpickles the file; only use checkpoints from a
    # trusted source (consider weights_only=True if the checkpoint allows it).
    checkpoint = torch.load(model_path, map_location=device)

    if "model_state_dict" not in checkpoint:
        # The file is treated as a bare state dict and loaded strictly.
        net.load_state_dict(checkpoint)
    else:
        try:
            net.load_state_dict(checkpoint["model_state_dict"])
        except RuntimeError as e:
            # Best-effort fallback: retry without the classifier entries and
            # without strict key checking, so the classifier weights are not
            # taken from the checkpoint on this path.
            print("Warning:", e)
            print("Loading partial weights, skipping classifier layer...")
            encoder_only = {}
            for key, tensor in checkpoint["model_state_dict"].items():
                if not key.startswith("classifier."):
                    encoder_only[key] = tensor
            net.load_state_dict(encoder_only, strict=False)

    net.eval()
    return net
# -------- Setup --------
# Runtime configuration: use the GPU when one is available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
EMBED_DIM = 512
LABELS_JSON_PATH = "labels.json"
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"

# Step 1: read the label file. Class indices follow the sorted order of the
# distinct labels, so index i maps to classes[i].
labels_dict = load_labels_json(LABELS_JSON_PATH)
classes = get_sorted_classes(labels_dict)
num_classes = len(classes)
idx_to_class = dict(enumerate(classes))

# Print the class count and a sample so it can be checked against training.
print(f"Loaded {num_classes} classes")
print(f"First 5 classes: {classes[:5]}")

# Step 2: fetch the repository snapshot and locate the checkpoint inside it.
print("Downloading model from repo...")
repo_dir = snapshot_download(repo_id=REPO_ID)
model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
print(f"Model path: {model_path}")

# Step 3: build the model with the checkpoint weights and the input pipeline.
# NOTE(review): load_model defaults to pretrained=True, so the encoder is
# first built with preset weights before the checkpoint is applied.
model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
transform = prepare_transforms()
# -------- Prediction Function --------
def predict(pil_img):
    """Predict the character shown in a PIL image.

    Args:
        pil_img: Image to classify, or ``None`` when no image was supplied.

    Returns:
        The entry of ``idx_to_class`` for the highest-scoring class, or an
        empty string when ``pil_img`` is ``None``.
    """
    # Submitting without an image yields None; return nothing rather than
    # failing inside the transform.
    if pil_img is None:
        return ""
    # The normalization step uses three-channel statistics, so grayscale or
    # RGBA input is converted to RGB first.
    rgb_img = pil_img.convert("RGB")
    # Add a leading batch dimension of size 1 before moving to the device.
    img_t = transform(rgb_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(img_t)
    pred_idx = output.argmax(dim=1).item()
    return idx_to_class[pred_idx]
# -------- Gradio Interface --------
# Build the single-image UI around predict() and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Handwritten Chinese Character"),
    outputs=gr.Text(label="Predicted Character"),
    title="Chinese Character Recognition",
    description="Recognizes handwritten Chinese characters with 80% accuracy",
)
demo.launch()