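"""Zero-shot pneumothorax evaluation on a Kaggle DICOM dataset.

For each DICOM file listed in the labels CSV, this script computes an image
embedding with RawImageModel, scores it against positive/negative text-prompt
embeddings from PrecomputedModel, and writes the per-image pneumothorax scores
(and any per-file errors) to a predictions CSV.
"""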
import os
import sys
import pandas as pd
import logging
import argparse
import numpy as np
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress TensorFlow logging (must be set before TensorFlow is imported)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass
logging.getLogger('tensorflow').setLevel(logging.ERROR)

from model import RawImageModel, PrecomputedModel
from dicom_utils import read_dicom_image
from PIL import Image

def main():
    parser = argparse.ArgumentParser(description="Evaluate on Kaggle DICOM Dataset")
    parser.add_argument("--csv", default="data/kaggle/labels.csv", help="Path to labels CSV")
    parser.add_argument("--data-dir", default="data/kaggle", help="Root directory for images if CSV paths are relative")
    parser.add_argument("--output", default="results/kaggle_predictions.csv", help="Output predictions file")
    args = parser.parse_args()

    # Create output directory
    os.makedirs(os.path.dirname(args.output), exist_ok=True)

    # Load dataset
    try:
        df = pd.read_csv(args.csv)
        logger.info(f"Loaded {len(df)} records from {args.csv}")
    except Exception as e:
        logger.error(f"Failed to load CSV: {e}")
        return

    # Determine which column holds the image paths ('file' or 'dicom_file')
    file_col = 'file' if 'file' in df.columns else 'dicom_file'
    if file_col not in df.columns:
        logger.error(f"Missing file column in CSV. Found: {list(df.columns)}")
        return
    # Initialize Models
    try:
        # We need PrecomputedModel for text embeddings (labels)
        precomputed_model = PrecomputedModel()
        # We need RawImageModel for the images
        raw_model = RawImageModel()
        logger.info("Models loaded successfully.")
    except Exception as e:
        logger.fatal(f"Failed to initialize models: {e}")
        return

    # Get text embeddings for diagnosis
    diagnosis = 'PNEUMOTHORAX'
    try:
        # Hardcoded prompts matching main.py
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.fatal(f"Failed to get text embeddings: {e}")
        return
    predictions = []

    # Iterate and predict
    print(f"Running inference for {diagnosis} on {len(df)} images...")
    temp_path = "temp_inference.png"
    for _, row in tqdm(df.iterrows(), total=len(df)):
        file_path = row[file_col]

        # Construct full path
        full_path = os.path.join(args.data_dir, file_path) if not os.path.isabs(file_path) else file_path

        # Check if file exists
        if not os.path.exists(full_path):
            logger.warning(f"File not found: {full_path}")
            predictions.append({
                'file': file_path,
                'true_label': None,
                'pneumothorax_score': None,
                'error': 'File not found'
            })
            continue

        true_label = row.get('label', row.get('PNEUMOTHORAX', 'Unknown'))

        try:
            # 1. Read DICOM
            image_array = read_dicom_image(full_path)
            # 2. Save as temp PNG (Required by RawImageModel/TF pipeline currently)
            Image.fromarray(image_array).save(temp_path)
            # 3. Compute Image Embedding
            img_emb = raw_model.compute_embeddings(temp_path)
            # 4. Compute Zero-Shot Score
            score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)
            predictions.append({
                'file': file_path,
                'true_label': true_label,
                'pneumothorax_score': float(score)
            })
        except Exception as e:
            # logger.warning(f"Failed to process {file_path}: {e}")
            predictions.append({
                'file': file_path,
                'true_label': true_label,
                'pneumothorax_score': None,
                'error': str(e)
            })

        # Incremental Save every 10 items
        if len(predictions) % 10 == 0:
            pd.DataFrame(predictions).to_csv(args.output, index=False)
    # Final Save
    results_df = pd.DataFrame(predictions)
    results_df.to_csv(args.output, index=False)
    logger.info(f"Predictions saved to {args.output}")

    # Cleanup
    if os.path.exists(temp_path):
        os.remove(temp_path)


if __name__ == "__main__":
    main()
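
# Example invocation (the script filename below is illustrative; the flags and
# defaults come from the argparse configuration above):
#   python evaluate_kaggle_dicom.py --csv data/kaggle/labels.csv \
#       --data-dir data/kaggle --output results/kaggle_predictions.csv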