"""Gradio demo that classifies HEp2 cell images with a pretrained model.

The model checkpoint and the example images are fetched from the
``MuGemSt/HEp2`` ModelScope repository when the module is imported.
"""

import functools
import os
from typing import Optional, Tuple

import gradio as gr
import torch
from modelscope import snapshot_download
from PIL import Image
from torchvision.transforms import transforms

# Local directory holding the downloaded snapshot (checkpoint + examples).
MODEL_DIR = snapshot_download("MuGemSt/HEp2", cache_dir="./__pycache__")

# Class name -> bilingual label shown in the UI.
TRANSLATE = {
    "Centromere": "着丝粒 Centromere",
    "Golgi": "高尔基体 Golgi",
    "Homogeneous": "同质 Homogeneous",
    "NuMem": "记忆体 NuMem",
    "Nucleolar": "核仁 Nucleolar",
    "Speckled": "斑核 Speckled",
}

# Position in this list is used as the predicted class index.
# NOTE(review): assumes the model's output order matches this key order.
CLASSES = list(TRANSLATE.keys())

# Deterministic preprocessing. The original pipeline also applied
# ``RandomAffine(5)``, a random augmentation that made repeated predictions
# on the same image non-reproducible, so it is left out at inference time.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def embeding(img_path: str) -> torch.Tensor:
    """Load the image at *img_path* and return it as a normalized tensor.

    The image is converted to RGB, resized, center-cropped to 224x224,
    converted to a tensor and normalized with mean/std 0.5 per channel.
    """
    # The context manager closes the underlying file handle; ``convert``
    # returns an independent in-memory image that stays valid afterwards.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")

    return _PREPROCESS(rgb)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the checkpoint once, on CPU, and put it in evaluation mode."""
    # NOTE: torch.load unpickles arbitrary objects; only use checkpoints
    # from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which may reject a fully pickled module -- confirm against the torch
    # version this demo is deployed with.
    model = torch.load(f"{MODEL_DIR}/save.pt", map_location=torch.device("cpu"))
    if isinstance(model, torch.nn.Module):
        # Disable train-time behavior of layers such as dropout/batch-norm.
        model.eval()

    return model


def infer(target: str) -> Tuple[Optional[str], str]:
    """Classify the cell image stored at *target*.

    Args:
        target: Path of the uploaded image; empty/``None`` when nothing
            was uploaded.

    Returns:
        ``(file name, bilingual class label)`` on success, or
        ``(None, prompt message)`` when no image was provided.
    """
    # Check the input first so an empty request never triggers model loading.
    if not target:
        return None, "请上传细胞图片 Please upload a cell picture!"

    model = _load_model()
    batch = embeding(target).unsqueeze(0)  # add the batch dimension

    # Gradients are not needed for prediction.
    with torch.no_grad():
        output = model(batch)

    predict = int(torch.argmax(output, dim=1).item())
    return os.path.basename(target), TRANSLATE[CLASSES[predict]]


if __name__ == "__main__":
    # One example image per class, shipped with the model snapshot.
    example_imgs = [f"{MODEL_DIR}/examples/{cls}.png" for cls in CLASSES]

    with gr.Blocks() as demo:
        gr.Interface(
            fn=infer,
            inputs=gr.Image(
                type="filepath", label="上传细胞图像 Upload a cell picture"
            ),
            outputs=[
                gr.Textbox(label="图片名 Picture name", show_copy_button=True),
                gr.Textbox(label="识别结果 Recognition result", show_copy_button=True),
            ],
            title="请上传 PNG 格式的 HEp2 细胞图片\nIt is recommended to upload HEp2 cell images in PNG format.",
            examples=example_imgs,
            allow_flagging="never",
            cache_examples=False,
        )

    demo.launch()