|
|
""" |
|
|
RadFig VQA Image Filtering Model - Inference Script |
|
|
Classifies medical images as suitable/unsuitable for VQA tasks. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import timm |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from albumentations import Compose, Resize, Normalize |
|
|
from albumentations.pytorch import ToTensorV2 |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
class Config:
    """Static settings shared by every stage of the inference pipeline.

    All values are class attributes, so they can be read either from the
    class itself (``Config.size``) or from an instance (``Config().size``).
    """

    # timm architecture identifier; also used as the checkpoint file-name prefix.
    model_name = "tf_efficientnetv2_s.in21k_ft_in1k"

    # Height and width (in pixels) every image is resized to before the model.
    size = 512

    # DataLoader settings.
    batch_size = 32
    num_workers = 4

    # Number of model outputs (one logit per image).
    target_size = 1

    # Number of fold checkpoints that make up the ensemble.
    n_fold = 5

    # Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
class TestDataset(Dataset):
    """Map-style dataset that reads images from disk for inference.

    Each item is the image at the matching path, decoded with OpenCV and
    converted from BGR to RGB. When a transform is supplied, the item is the
    ``'image'`` entry of the transform's output instead.
    """

    def __init__(self, image_paths, transform=None):
        """Remember the paths to read and the optional transform.

        Args:
            image_paths: Indexable, sized collection of image file paths.
            transform: Optional callable invoked as ``transform(image=...)``
                whose result is indexed with ``'image'``.
        """
        self.transform = transform
        self.image_paths = image_paths

    def __len__(self):
        """Return how many image paths the dataset holds."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, colour-convert and (optionally) transform one image.

        Args:
            idx: Position of the image path to load.

        Returns:
            The RGB image, or the transform's ``'image'`` output when a
            transform is set.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        image_path = self.image_paths[idx]

        bgr = cv2.imread(image_path)
        # cv2.imread reports failure by returning None instead of raising.
        if bgr is None:
            raise ValueError(f"Could not load image: {image_path}")

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Truthiness (not an identity check) decides whether to transform.
        if not self.transform:
            return rgb
        return self.transform(image=rgb)['image']
|
|
|
|
|
|
|
|
def get_transforms():
    """Build the deterministic preprocessing pipeline used for inference.

    Returns:
        An albumentations ``Compose`` that resizes the image to
        ``Config.size`` x ``Config.size``, normalizes each channel with the
        mean/std listed below (the commonly used ImageNet statistics), and
        converts the result with ``ToTensorV2``.
    """
    side = Config.size
    steps = [
        Resize(side, side),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return Compose(steps)
|
|
|
|
|
|
|
|
class RadFigClassifier:
    """RadFig VQA image-filtering classifier backed by a fold ensemble.

    One checkpoint per fold is loaded from ``model_dir`` at construction
    time. Predictions are the mean of the per-fold sigmoid outputs, i.e. a
    score in [0, 1] per image where higher means "more suitable for VQA".

    NOTE(review): two different cut-offs are used in this class. Binary
    predictions default to ``> 0.5`` while the ``suitable_for_vqa`` column of
    ``predict_directory`` defaults to ``> 0.9`` in probability mode. Both are
    kept as defaults for backward compatibility and are exposed as keyword
    arguments; confirm which one the model was calibrated for.
    """

    def __init__(self, model_dir="models"):
        """Load every fold checkpoint found in ``model_dir``.

        Args:
            model_dir (str): Directory holding the
                ``{model_name}_fold{k}_best_loss.pth`` files.

        Raises:
            FileNotFoundError: If any fold checkpoint is missing.
        """
        self.config = Config()
        self.model_dir = model_dir
        self.device = self.config.device
        # Kept for backward compatibility; the ensemble lives in _fold_models.
        self.model = None
        self.states = []
        # Lazily-built list of eval-mode models, one per entry of self.states.
        self._fold_models = None

        self._load_model_states()

    def _load_model_states(self):
        """Load all fold checkpoints into ``self.states``.

        Raises:
            FileNotFoundError: If a fold checkpoint does not exist.
        """
        self.states = []
        # Any previously built models would be stale after a reload.
        self._fold_models = None

        for fold in range(self.config.n_fold):
            model_path = os.path.join(
                self.model_dir,
                f"{self.config.model_name}_fold{fold}_best_loss.pth"
            )

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from a trusted source (or consider weights_only=True
            # if the checkpoint format allows it).
            state = torch.load(model_path, map_location=self.device)
            self.states.append(state)

        print(f"Loaded {len(self.states)} model states from {self.model_dir}")

    def _create_model(self):
        """Create an uninitialised model of the configured architecture.

        Returns:
            The timm model, moved to ``self.device``. Weights are not loaded
            here (``pretrained=False``).
        """
        model = timm.create_model(
            model_name=self.config.model_name,
            num_classes=self.config.target_size,
            pretrained=False
        )
        return model.to(self.device)

    def _get_fold_models(self):
        """Return one eval-mode model per loaded fold state, building them once.

        Loading a state dict is loop-invariant work, so it is done a single
        time per fold here instead of once per fold for every batch.
        """
        models = getattr(self, "_fold_models", None)
        if models is None:
            models = []
            for state in self.states:
                model = self._create_model()
                model.load_state_dict(state['model'])
                model.eval()
                models.append(model)
            self._fold_models = models
        return models

    def predict_batch(self, image_paths, return_probabilities=True, threshold=0.5):
        """
        Predict on a batch of images

        Args:
            image_paths (list): List of image file paths
            return_probabilities (bool): If True, return probabilities. If False, return binary predictions.
            threshold (float): Cut-off applied to the averaged probability when
                binary predictions are requested (``probability > threshold``).

        Returns:
            numpy.ndarray: 1-D array with one entry per input path, in input
            order (probabilities or 0/1 integers). Empty when ``image_paths``
            is empty.

        Raises:
            ValueError: If an image cannot be read (raised by the dataset).
        """
        image_paths = list(image_paths)
        if not image_paths:
            # np.concatenate rejects an empty list; answer directly instead.
            empty = np.empty(0, dtype=np.float32)
            return empty if return_probabilities else empty.astype(int)

        dataset = TestDataset(image_paths, transform=get_transforms())
        dataloader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            # Pinned host memory only benefits host-to-GPU copies; requesting
            # it on a CPU-only machine just triggers a warning.
            pin_memory=torch.device(self.device).type == "cuda",
        )

        models = self._get_fold_models()

        all_predictions = []
        with torch.no_grad():
            for images in tqdm(dataloader, desc="Predicting"):
                images = images.to(self.device)

                fold_predictions = [
                    torch.sigmoid(model(images)).cpu().numpy()
                    for model in models
                ]

                # Ensemble: mean probability across folds for this batch.
                all_predictions.append(np.mean(fold_predictions, axis=0))

        predictions = np.concatenate(all_predictions, axis=0).flatten()

        if return_probabilities:
            return predictions
        return (predictions > threshold).astype(int)

    def predict_single(self, image_path, return_probability=True, threshold=0.5):
        """
        Predict on a single image

        Args:
            image_path (str): Path to image file
            return_probability (bool): If True, return probability. If False, return binary prediction.
            threshold (float): Cut-off used when a binary prediction is requested.

        Returns:
            float or int: Prediction (a NumPy scalar)
        """
        predictions = self.predict_batch(
            [image_path],
            return_probabilities=return_probability,
            threshold=threshold,
        )
        return predictions[0]

    def predict_directory(self, directory_path, output_csv=None, return_probabilities=True,
                          suitable_threshold=0.9):
        """
        Predict on all images in a directory

        Only regular files directly inside ``directory_path`` whose name ends
        with a known image extension (case-insensitive) are considered;
        sub-directories are not searched.

        Args:
            directory_path (str): Path to directory containing images
            output_csv (str, optional): Path to save results as CSV
            return_probabilities (bool): If True, return probabilities. If False, return binary predictions.
            suitable_threshold (float): In probability mode, an image is marked
                ``suitable_for_vqa`` when its probability exceeds this value.
                In binary mode the column simply mirrors the binary prediction
                (which uses ``predict_batch``'s default cut-off).

        Returns:
            pandas.DataFrame: Results with columns ``image_path``,
            ``filename``, ``prediction`` and ``suitable_for_vqa``, sorted by
            filename.

        Raises:
            ValueError: If the directory contains no image files.
        """
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')

        image_paths = []
        for filename in os.listdir(directory_path):
            if not filename.lower().endswith(image_extensions):
                continue
            path = os.path.join(directory_path, filename)
            # Skip directories that merely look like images by name.
            if os.path.isfile(path):
                image_paths.append(path)

        if not image_paths:
            raise ValueError(f"No image files found in {directory_path}")

        print(f"Found {len(image_paths)} images in {directory_path}")

        predictions = self.predict_batch(image_paths, return_probabilities=return_probabilities)

        if return_probabilities:
            suitable = predictions > suitable_threshold
        else:
            suitable = predictions.astype(bool)

        results = pd.DataFrame({
            'image_path': image_paths,
            'filename': [os.path.basename(path) for path in image_paths],
            'prediction': predictions,
            'suitable_for_vqa': suitable,
        })

        # os.listdir order is arbitrary; sort for reproducible output.
        results = results.sort_values('filename').reset_index(drop=True)

        if output_csv:
            results.to_csv(output_csv, index=False)
            print(f"Results saved to {output_csv}")

        return results
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: classify one image or a directory of images.

    Arguments (parsed from ``sys.argv``):
        --input   Image file or directory to classify (required).
        --models  Directory containing the fold checkpoints (default "models").
        --output  CSV path for the results; only used for directory input.
        --binary  Print/store binary predictions instead of probabilities.

    An ``--input`` that is neither a file nor a directory is rejected through
    ``parser.error`` (message on stderr, exit status 2) *before* any model
    checkpoint is loaded.

    NOTE(review): probability mode labels an image "suitable" above 0.9,
    whereas ``--binary`` relies on the classifier's own binary cut-off, so the
    two modes can disagree for the same image.
    """
    import argparse

    parser = argparse.ArgumentParser(description="RadFig VQA Image Filtering Inference")
    parser.add_argument("--input", required=True, help="Input image file or directory")
    parser.add_argument("--models", default="models", help="Directory containing model files")
    parser.add_argument("--output", help="Output CSV file (for directory input)")
    parser.add_argument("--binary", action="store_true", help="Return binary predictions instead of probabilities")

    args = parser.parse_args()

    input_is_file = os.path.isfile(args.input)
    if not input_is_file and not os.path.isdir(args.input):
        # Fail fast: no point loading every checkpoint for a bad path, and a
        # failed run must not exit with status 0.
        parser.error(f"{args.input} is not a valid file or directory")

    classifier = RadFigClassifier(model_dir=args.models)

    if input_is_file:
        prediction = classifier.predict_single(
            args.input,
            return_probability=not args.binary
        )

        print(f"Image: {args.input}")
        if args.binary:
            result = "suitable" if prediction else "not suitable"
            print(f"Prediction: {result} for VQA")
        else:
            print(f"Probability suitable for VQA: {prediction:.4f}")
            print(f"Classification: {'suitable' if prediction > 0.9 else 'not suitable'}")
        return

    results = classifier.predict_directory(
        args.input,
        output_csv=args.output,
        return_probabilities=not args.binary
    )

    if args.binary:
        suitable_count = results['suitable_for_vqa'].sum()
    else:
        suitable_count = (results['prediction'] > 0.9).sum()

    total_count = len(results)

    print("\nSummary:")
    print(f"Total images: {total_count}")
    print(f"Suitable for VQA: {suitable_count}")
    print(f"Not suitable for VQA: {total_count - suitable_count}")
    print(f"Percentage suitable: {suitable_count/total_count*100:.1f}%")
|
|
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()