Spaces:
Sleeping
Sleeping
# Zero-shot chest X-ray classification CLI: scores images against a positive
# and a negative text prompt for one diagnosis (see DIAGNOSIS_PROMPTS / main).
#
# NOTE(review): statement order in this setup section is deliberate. The
# TensorFlow environment variables and logger levels are configured before
# the project-local `model` / `evaluate` imports, presumably because those
# modules import TensorFlow, which reads these settings at import time.
# Do not hoist the project imports above them.
import os
import argparse
import logging
import sys

# Suppress TensorFlow and system warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import warnings

# Ignore every Python warning for the whole process (no category/module filter).
warnings.filterwarnings('ignore')
# NOTE(review): np / pd are not referenced in the code visible in this file —
# confirm they are unused before removing them.
import numpy as np
import pandas as pd

# Configure logging first
# Root logger at INFO with a timestamped format; records from `logger` below
# propagate to the root handler installed here.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress absl logging from TensorFlow
# absl is optional: if it is not installed, this step is skipped silently.
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass

# Suppress TensorFlow Python logging
# The 'tensorflow' logger's level is set to ERROR here, before `model`
# (presumably) imports TensorFlow below.
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Project-local modules (defined outside this file).
from model import PrecomputedModel, RawImageModel
from evaluate import evaluate_predictions
# Zero-shot text prompts keyed by diagnosis label, each a pair of
# (positive_prompt, negative_prompt). main() offers the keys as the
# --diagnosis choices and uses the chosen key as the label column name.
# NOTE(review): the prompt strings are passed verbatim to
# PrecomputedModel.get_diagnosis_embeddings and may be matched against
# precomputed text embeddings by exact text (including the capitalisation of
# 'Airspace Opacity') — confirm before editing any of them.
DIAGNOSIS_PROMPTS = dict(
    AIRSPACE_OPACITY=('Airspace Opacity', 'no evidence of airspace disease'),
    PNEUMOTHORAX=('small pneumothorax', 'no pneumothorax'),
    EFFUSION=('large pleural effusion', 'no pleural effusion'),
    PULMONARY_EDEMA=('moderate pulmonary edema', 'no pulmonary edema'),
)
def _run_raw_image_inference(raw_image_path, diagnosis, pos_emb, neg_emb):
    """Score a single raw image file against the diagnosis text embeddings.

    Logs the score and prints it to stdout. If embedding or scoring the image
    raises, the error is logged with its traceback and the process exits with
    status 1.
    """
    logger.info(f"Running inference on raw image: {raw_image_path}")
    raw_model = RawImageModel()
    try:
        image_emb = raw_model.compute_embeddings(raw_image_path)
        # image_emb shape is likely (1, 32, 128) or (32, 128)
        # PrecomputedModel.zero_shot expects flattened or (32, 128)
        score = PrecomputedModel.zero_shot(image_emb, pos_emb, neg_emb)
    except Exception as e:
        # CLI boundary: any failure here is fatal. logger.exception keeps the
        # traceback so the root cause is not lost.
        logger.exception(f"Failed to process raw image: {e}")
        sys.exit(1)

    logger.info(f"Zero-shot score for {raw_image_path}: {score:.4f}")
    # Since we only have one image, we can't calculate AUC meaningfully
    # unless we run it against the full validation set which takes time.
    # For this demo, just output the score.
    print(f"Score for {diagnosis}: {score}")


def _run_dataset_evaluation(precomputed_model, diagnosis, pos_emb, neg_emb):
    """Score every image labelled 0 or 1 for `diagnosis` and evaluate the scores.

    Exits with status 1 if no scores could be computed, or if the labels cannot
    be paired one-to-one with the returned scores.
    """
    logger.info("Running evaluation on full precomputed dataset...")
    # Keep only rows whose label for this diagnosis is 0 or 1.
    labels_df = precomputed_model.labels
    target_df = labels_df.loc[labels_df[diagnosis].isin([0, 1]), ['image_id', diagnosis]]
    image_ids = target_df['image_id'].tolist()
    true_labels = target_df[diagnosis].tolist()

    valid_ids, scores = precomputed_model.compute_scores(image_ids, pos_emb, neg_emb)
    # Check emptiness with len(): `not scores` raises ValueError when
    # compute_scores returns a numpy array with more than one element.
    if scores is None or len(scores) == 0:
        logger.error("No valid scores computed. Check embedding match.")
        sys.exit(1)

    # Look labels up by id, in valid_ids order, so each label pairs with its
    # score even if compute_scores reorders or drops ids. This assumes
    # scores[i] belongs to valid_ids[i] (TODO: confirm against compute_scores).
    # The dict lookup also avoids an O(n*m) membership scan of valid_ids.
    label_by_id = dict(zip(image_ids, true_labels))
    final_labels = [label_by_id[img_id] for img_id in valid_ids]
    if len(final_labels) != len(scores):
        logger.error(
            f"Label/score count mismatch: {len(final_labels)} labels vs "
            f"{len(scores)} scores."
        )
        sys.exit(1)

    evaluate_predictions(scores, final_labels, diagnosis)


def main(argv=None):
    """Run zero-shot chest X-ray classification for one diagnosis.

    With ``--raw-image``, a single image file is embedded and scored. Otherwise
    every image labelled 0/1 for the diagnosis in the precomputed dataset under
    ``--data-dir`` is scored and evaluated.

    Args:
        argv: Optional argument list for programmatic use. ``None`` (the
            default) parses ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(description="Zero-Shot Chest X-Ray Classification")
    parser.add_argument(
        "--diagnosis", type=str, choices=DIAGNOSIS_PROMPTS.keys(), required=True,
        help="Diagnosis to evaluate",
    )
    parser.add_argument("--data-dir", type=str, default="data", help="Path to data directory")
    parser.add_argument(
        "--raw-image", type=str,
        help="Path to a raw image file for inference (optional)",
    )
    args = parser.parse_args(argv)

    # Get prompts
    pos_txt, neg_txt = DIAGNOSIS_PROMPTS[args.diagnosis]
    logger.info(f"Diagnosis: {args.diagnosis}")
    logger.info(f"Positive query: '{pos_txt}'")
    logger.info(f"Negative query: '{neg_txt}'")

    # Text embeddings always come from the precomputed model, even in
    # raw-image mode; image embeddings come from it only in dataset mode.
    precomputed_model = PrecomputedModel(data_dir=args.data_dir)
    pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)

    if args.raw_image:
        _run_raw_image_inference(args.raw_image, args.diagnosis, pos_emb, neg_emb)
    else:
        _run_dataset_evaluation(precomputed_model, args.diagnosis, pos_emb, neg_emb)
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()