MedPMC
Collection
MedPMC resources, including the data curation pipeline, curated datasets, and trained vision-language models. • 17 items • Updated • 1
How to use Yale-BIDS-Chen/medpmc-multi-fig-detection-vit with timm:
import timm
model = timm.create_model("hf_hub:Yale-BIDS-Chen/medpmc-multi-fig-detection-vit", pretrained=True)

This repository provides the vision transformer-based multi-figure detection model used in the MedPMC data curation pipeline.
The model is a binary image classifier trained to predict whether a biomedical figure is a multi-panel / compound figure or a single-panel figure. It is intended for processing figures from biomedical literature, especially figures from PubMed Central (PMC) articles.
The model performs binary image classification.
0: single-panel figure
1: multi-panel / compound figure
import torch
import timm
from PIL import Image
from torchvision import transforms
# Single-image inference: load the released checkpoint, rebuild the timm model
# it describes, and classify one figure as single-panel (0) or
# multi-panel / compound (1).
checkpoint_path = "model.pth.tar"
image_path = "example.jpg"
device = "cuda" if torch.cuda.is_available() else "cpu"

# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints that
# come from a source you trust. The tensors are mapped to the CPU here; the
# model is moved to `device` after its weights are loaded.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
arch = checkpoint["arch"]  # model name handed to timm.create_model below
state_dict = checkpoint["state_dict"]

# Remove DataParallel/DDP prefix if present.
state_dict = {
    k.replace("module.", "", 1) if k.startswith("module.") else k: v
    for k, v in state_dict.items()
}

# Binary classifier.
model = timm.create_model(
    arch,
    pretrained=False,
    num_classes=2,
)
# strict=True makes a mismatch between checkpoint and architecture fail loudly
# instead of silently leaving layers uninitialised.
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()

# Fixed 224x224 input, normalised with the commonly used ImageNet channel
# statistics.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# The context manager closes the underlying file as soon as the RGB copy has
# been made.
with Image.open(image_path) as img:
    image = img.convert("RGB")

# unsqueeze(0) adds the batch dimension the model expects.
inputs = preprocess(image).unsqueeze(0).to(device)

with torch.no_grad():
    logits = model(inputs)
    probs = torch.softmax(logits, dim=-1)
    pred = torch.argmax(probs, dim=-1).item()

print("Prediction:", pred)
print("Probabilities:", probs.cpu().tolist())
Example output:
Prediction: 1
Probabilities: [[0.08, 0.92]]
This means that the model classifies the input image as a multi-panel / compound figure (class 1), with a predicted probability of 0.92.
import torch
import timm
from PIL import Image
from pathlib import Path
from torchvision import transforms
# Batch inference: load the released checkpoint once, then classify every
# JPEG/PNG file found directly inside `image_dir` as single-panel (0) or
# multi-panel / compound (1).
checkpoint_path = "model.pth.tar"
image_dir = "sample"
device = "cuda" if torch.cuda.is_available() else "cpu"

# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints that
# come from a source you trust. The tensors are mapped to the CPU here; the
# model is moved to `device` after its weights are loaded.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
arch = checkpoint["arch"]  # model name handed to timm.create_model below
state_dict = checkpoint["state_dict"]

# Remove DataParallel/DDP prefix if present.
state_dict = {
    k.replace("module.", "", 1) if k.startswith("module.") else k: v
    for k, v in state_dict.items()
}

model = timm.create_model(arch, pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()

# Fixed 224x224 input, normalised with the commonly used ImageNet channel
# statistics.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# Match extensions case-insensitively so files such as "fig.JPG" or "fig.PNG"
# are picked up on case-sensitive filesystems too, and skip directories whose
# names merely end in an image extension.
image_extensions = (".jpg", ".jpeg", ".png")
image_paths = sorted(
    path
    for path in Path(image_dir).glob("*")
    if path.is_file() and path.name.lower().endswith(image_extensions)
)

for image_path in image_paths:
    # The context manager closes each file as soon as its RGB copy has been
    # made, so handles do not accumulate over a large directory.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # unsqueeze(0) adds the batch dimension the model expects.
    inputs = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(inputs)
        probs = torch.softmax(logits, dim=-1)
        pred = torch.argmax(probs, dim=-1).item()

    print("Image:", image_path)
    print("Prediction:", pred)
    print("Probabilities:", probs.cpu().tolist())
The model is released for non-commercial research use under CC BY-NC-SA 4.0.
Citation information will be updated soon.