""" |
|
|
Zero-shot inference script for demo images using Hugging Face models. |
|
|
Runs inference on images in demo/iwildcam_demo_images using specified models |
|
|
and saves results to JSON files. |
|
|
""" |
|
|
|
|
|

import os
import json
import warnings
from collections import OrderedDict

import torch
from PIL import Image
from transformers import pipeline

warnings.filterwarnings("ignore")

# open_clip is optional; it is only needed for the OpenCLIP-loaded models
# (the PE-Core and LAION checkpoints below).
try:
    import open_clip
    OPEN_CLIP_AVAILABLE = True
except ImportError:
    OPEN_CLIP_AVAILABLE = False

# iWildCam category IDs mapped to the species names used for prompting.
SPECIES_MAP = OrderedDict([
    (24, "Jaguar"),
    (10, "Ocelot"),
    (6, "Mountain Lion"),
    (101, "Common Eland"),
    (102, "Waterbuck"),
])

CLASS_NAMES = list(SPECIES_MAP.values())

# Alternative, more descriptive prompt strings (currently unreferenced in
# this script; kept for experimenting with prompt phrasing).
DESCRIPTIVE_CLASS_NAMES = [
    "a jaguar cat",
    "an ocelot cat",
    "a mountain lion cougar",
    "a common eland antelope",
    "a waterbuck antelope",
]

MODELS = [
    "openai/clip-vit-large-patch14",
    "google/siglip2-large-patch16-384",
    "google/siglip2-large-patch16-512",
    "google/siglip2-so400m-patch16-naflex",
    "facebook/PE-Core-L14-336",
    "laion/CLIP-ViT-L-14-laion2B-s32B-b82K",
]

def load_demo_annotations():
    """Load the demo annotations and map file names to species metadata."""
    with open('iwildcam_demo_annotations.json', 'r') as f:
        data = json.load(f)

    # Index images by id once instead of rescanning the image list per annotation.
    images_by_id = {img['id']: img for img in data['images']}

    image_metadata = {}
    for annotation in data['annotations']:
        image_info = images_by_id.get(annotation['image_id'])
        if image_info:
            image_metadata[image_info['file_name']] = {
                'species_id': annotation['category_id'],
                'species_name': SPECIES_MAP.get(annotation['category_id'], "Unknown"),
            }

    return image_metadata
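
# For reference, a minimal sketch of the annotation layout this loader expects
# (field names match what is read above; the values are illustrative only):
#
#   {
#     "images":      [{"id": "img1", "file_name": "img1.jpg"}, ...],
#     "annotations": [{"image_id": "img1", "category_id": 24}, ...]
#   }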

def run_openclip_inference(model_name, image_paths, class_names):
    """Run zero-shot inference using OpenCLIP models."""
    if not OPEN_CLIP_AVAILABLE:
        print("open_clip is not available. Please install it with: pip install open_clip_torch")
        return None

    print(f"Loading OpenCLIP model: {model_name}")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if model_name == "facebook/PE-Core-L14-336":
            # NOTE: this loads MetaCLIP ViT-L-14 weights as a stand-in, not the
            # actual PE-Core checkpoint (see the sketch after this function).
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='metaclip_fullcc')
        elif model_name == "laion/CLIP-ViT-L-14-laion2B-s32B-b82K":
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
        else:
            print(f"Unknown OpenCLIP model: {model_name}")
            return None

        model = model.to(device)
        model.eval()
        tokenizer = open_clip.get_tokenizer('ViT-L-14')

        # Encode the text prompts once; they are shared across all images.
        prompts = [f"a photo of a {class_name.lower()}" for class_name in class_names]
        text_tokens = tokenizer(prompts).to(device)

        results = {}

        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            for i, image_path in enumerate(image_paths):
                if i % 10 == 0:
                    print(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")

                try:
                    image = Image.open(image_path).convert("RGB")
                    image_tensor = preprocess(image).unsqueeze(0).to(device)

                    image_features = model.encode_image(image_tensor)
                    image_features /= image_features.norm(dim=-1, keepdim=True)

                    # CLIP-style scoring: scaled cosine similarity, softmaxed over classes.
                    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
                    probabilities = similarity.squeeze(0).cpu().numpy()

                    scores = {class_name: float(probabilities[j]) for j, class_name in enumerate(class_names)}
                    results[os.path.basename(image_path)] = scores

                except Exception as e:
                    print(f"Error processing {image_path}: {e}")
                    # Fall back to a uniform distribution so every image still
                    # gets a full score dict.
                    uniform_prob = 1.0 / len(class_names)
                    results[os.path.basename(image_path)] = {class_name: uniform_prob for class_name in class_names}

        return results

    except Exception as e:
        print(f"Error loading OpenCLIP model {model_name}: {e}")
        return None
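
# Hedged sketch: recent open_clip releases can load checkpoints straight from
# the Hugging Face Hub via an "hf-hub:" model reference. If the PE-Core repo
# ships an open_clip-compatible config (untested assumption), the real weights
# could be loaded like this instead of the MetaCLIP stand-in above:
#
#   model, _, preprocess = open_clip.create_model_and_transforms(
#       'hf-hub:facebook/PE-Core-L14-336')
#   tokenizer = open_clip.get_tokenizer('hf-hub:facebook/PE-Core-L14-336')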

def run_siglip_inference(model_name, image_paths, class_names):
    """Run zero-shot inference using SigLIP with manual CLIP-style scoring."""
    print(f"Loading SigLIP model: {model_name}")
    try:
        from transformers import AutoProcessor, AutoModel
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()

        results = {}

        with torch.no_grad():
            for i, image_path in enumerate(image_paths):
                if i % 10 == 0:
                    print(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")

                try:
                    image = Image.open(image_path).convert("RGB")
                    prompts = [f"This is a photo of a {class_name.lower()}" for class_name in class_names]
                    inputs = processor(
                        text=prompts,
                        images=image,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True
                    ).to(device)

                    outputs = model(**inputs)
                    logits_per_image = outputs.logits_per_image
                    # SigLIP is trained with a sigmoid loss, so sigmoid(logits) is
                    # its native score; softmax is applied here instead so the
                    # scores form a distribution comparable to the CLIP-style models.
                    probabilities = torch.softmax(logits_per_image, dim=-1).squeeze(0)

                    scores = {}
                    for j, class_name in enumerate(class_names):
                        scores[class_name] = probabilities[j].item()

                    results[os.path.basename(image_path)] = scores

                except Exception as e:
                    print(f"Error processing {image_path}: {e}")
                    results[os.path.basename(image_path)] = {class_name: 0.0 for class_name in class_names}

        return results

    except Exception as e:
        print(f"Error loading SigLIP model {model_name}: {e}")
        return None
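
# Hedged alternative: SigLIP's native scoring treats each class independently
# with a sigmoid rather than normalizing across classes. If detection-style
# per-class scores were preferred, the softmax line above could be replaced by:
#
#   probabilities = torch.sigmoid(logits_per_image).squeeze(0)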

def run_zeroshot_inference(model_name, image_paths, class_names):
    """Run zero-shot inference using the Hugging Face zero-shot pipeline."""
    print(f"Loading model: {model_name}")

    try:
        classifier = pipeline(
            "zero-shot-image-classification",
            model=model_name,
            device=0 if torch.cuda.is_available() else -1
        )

        results = {}

        for i, image_path in enumerate(image_paths):
            if i % 10 == 0:
                print(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")

            try:
                image = Image.open(image_path).convert("RGB")
                prompts = [f"a photo of a {class_name.lower()}" for class_name in class_names]
                # candidate_labels must be passed by keyword; the pipeline echoes
                # each prompt back as the 'label' of its score entry.
                outputs = classifier(image, candidate_labels=prompts)

                # Map each returned prompt back to its class name.
                prompt_to_class = dict(zip(prompts, class_names))
                scores = {}
                for output in outputs:
                    class_name = prompt_to_class.get(output['label'])
                    if class_name is not None:
                        scores[class_name] = output['score']

                # Ensure every class has a score even if a prompt went unmatched.
                for class_name in class_names:
                    scores.setdefault(class_name, 0.0)

                results[os.path.basename(image_path)] = scores

            except Exception as e:
                print(f"Error processing {image_path}: {e}")
                results[os.path.basename(image_path)] = {class_name: 0.0 for class_name in class_names}

        return results

    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None
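
# Example (hedged): scoring a single image by hand; "some_image.jpg" is a
# hypothetical file name, not one shipped with the demo.
#
#   paths = [os.path.join("iwildcam_demo_images", "some_image.jpg")]
#   scores = run_zeroshot_inference("openai/clip-vit-large-patch14", paths, CLASS_NAMES)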

def main():
    """Run zero-shot inference with every model and save results to JSON."""
    image_dir = "iwildcam_demo_images"
    image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
    image_paths = [os.path.join(image_dir, f) for f in image_files]

    print(f"Found {len(image_files)} demo images")

    image_metadata = load_demo_annotations()
    print(f"Loaded metadata for {len(image_metadata)} images")

    for model_name in MODELS:
        print(f"\n{'='*60}")
        print(f"Running inference with {model_name}")
        print(f"{'='*60}")

        model_safe_name = model_name.replace("/", "_").replace("-", "_")
        output_file = f"zeroshot_results_{model_safe_name}.json"

        if os.path.exists(output_file):
            print(f"Results file {output_file} already exists, skipping {model_name}")
            continue

        # Dispatch to the loader that matches each model family.
        if model_name in ["imageomics/bioclip", "imageomics/bioclip-2"]:
            # BioCLIP models are handled by the pybioclip package, not here;
            # skip rather than abort the remaining models.
            print("Use pybioclip!")
            continue
        elif model_name.startswith("google/siglip"):
            results = run_siglip_inference(model_name, image_paths, CLASS_NAMES)
        elif model_name in ["facebook/PE-Core-L14-336", "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"]:
            results = run_openclip_inference(model_name, image_paths, CLASS_NAMES)
        else:
            results = run_zeroshot_inference(model_name, image_paths, CLASS_NAMES)

        if results is not None:
            output_data = {
                "model": model_name,
                "class_names": CLASS_NAMES,
                "num_images": len(results),
                "results": results
            }

            with open(output_file, 'w') as f:
                json.dump(output_data, f, indent=2)

            print(f"Results saved to {output_file}")

            # Print the top-3 classes for a few images as a quick sanity check.
            sample_images = list(results.keys())[:3]
            print(f"\nSample results from {model_name}:")
            for img in sample_images:
                print(f"  {img}:")
                scores = results[img]
                sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
                for class_name, score in sorted_scores[:3]:
                    print(f"    {class_name}: {score:.4f}")
        else:
            print(f"Failed to run inference with {model_name}")
if __name__ == "__main__": |
|
|
|
|
|
os.chdir(os.path.dirname(os.path.abspath(__file__))) |
|
|
main() |
|
|
|