# Uploaded by maxwoe (Hugging Face Hub).
# Commit message: "Upload model_cgd.py with huggingface_hub"
# Commit: ed8d2c4 (verified)
"""
Circular Gaussian Distribution (CGD) for Image Orientation Estimation (Inference Only)
Represents angles as probability distributions over discretized angle bins.
Model output: Probability distribution over 360 angle bins (1 degree resolution)
"""
import math
from typing import Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import pytorch_lightning as pl
import timm
import timm.data
from PIL import Image
import numpy as np
from loguru import logger
class CircularGaussianDistribution(nn.Module):
    """Circular Gaussian Distribution module for 360 degree image orientation.

    The full circle [0, 360) is split into ``num_bins`` equally sized bins.
    The module converts a per-bin probability distribution back into an
    angle and can report an entropy-based uncertainty for a distribution.

    Attributes:
        num_bins: Number of angle bins covering the circle.
        sigma: Gaussian width passed at construction; stored only, it is not
            read by any method of this class.
        bin_size: Width of one bin in degrees (``360 / num_bins``).
        bin_centers: Buffer of shape ``[num_bins]`` holding the angle (in
            degrees) associated with each bin: ``0, bin_size, 2*bin_size, ...``.
    """

    def __init__(self, num_bins: int = 360, sigma: float = 6.0):
        """Create the bin layout.

        Args:
            num_bins: Number of bins; must be at least 1.
            sigma: Gaussian width, stored on the instance.

        Raises:
            ValueError: If ``num_bins`` is smaller than 1.
        """
        super().__init__()
        if num_bins < 1:
            raise ValueError(f"num_bins must be a positive integer, got {num_bins}")
        self.num_bins = num_bins
        self.sigma = sigma
        self.bin_size = 360.0 / num_bins
        # Scale integer indices instead of calling arange() with a float step:
        # a non-integer step can yield num_bins + 1 (or num_bins - 1) elements
        # through rounding when num_bins does not divide 360 evenly.
        bin_centers = torch.arange(num_bins, dtype=torch.get_default_dtype()) * self.bin_size
        self.register_buffer('bin_centers', bin_centers)
        logger.info(f"CGD: {num_bins} bins, range [0, 360), sigma={sigma}")

    @staticmethod
    def _wrap_degrees(angles: torch.Tensor) -> torch.Tensor:
        """Map angles in degrees onto the half-open interval [0, 360)."""
        wrapped = angles % 360.0
        # A tiny negative input can round up to exactly 360.0 after the
        # modulo; fold that value back to 0 so the interval stays half-open.
        return torch.where(wrapped >= 360.0, torch.zeros_like(wrapped), wrapped)

    def _refine_peak(self, distributions: torch.Tensor) -> torch.Tensor:
        """Return the argmax angle refined by a parabola through its neighbours.

        The neighbours are taken modulo ``num_bins`` so that peaks in the
        first and last bin are refined across the 0/360 seam as well.
        The result is not yet wrapped to [0, 360).
        """
        peak_indices = torch.argmax(distributions, dim=1)
        rows = torch.arange(distributions.shape[0], device=distributions.device)
        left = distributions[rows, (peak_indices - 1) % self.num_bins]
        centre = distributions[rows, peak_indices]
        right = distributions[rows, (peak_indices + 1) % self.num_bins]

        # Parabola y = a*x^2 + b*x + c through (-1, left), (0, centre),
        # (1, right); its vertex sits at x = -b / (2a).
        curvature = 0.5 * (left - 2.0 * centre + right)
        slope = 0.5 * (right - left)

        # A (near) flat top has no well-defined vertex: keep the bin centre.
        well_defined = curvature.abs() > 1e-8
        safe_curvature = torch.where(well_defined, curvature, torch.ones_like(curvature))
        offset = torch.where(
            well_defined,
            -slope / (2.0 * safe_curvature),
            torch.zeros_like(slope),
        )
        # The true peak cannot lie more than half a bin from the argmax bin.
        offset = offset.clamp(-0.5, 0.5)
        return self.bin_centers[peak_indices] + offset * self.bin_size

    def distribution_to_angle(self, distributions: torch.Tensor, method: str = 'argmax') -> torch.Tensor:
        """Extract angles from probability distributions.

        Args:
            distributions: Probability distributions [B, num_bins]
            method: 'argmax' (centre of the most likely bin),
                'weighted_average' (circular mean of the bin centres weighted
                by the distribution) or 'peak_fitting' (argmax refined to
                sub-bin precision with a parabola through the neighbouring
                bins, wrapping around the 0/360 seam).

        Returns:
            angles: Extracted angles in degrees [B] in [0, 360)

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == 'argmax':
            angles = self.bin_centers[torch.argmax(distributions, dim=1)]
        elif method == 'weighted_average':
            # Renormalise so the weights sum to 1 even for unnormalised input.
            weights = distributions / (distributions.sum(dim=1, keepdim=True) + 1e-8)
            bin_angles_rad = self.bin_centers * torch.pi / 180.0
            # Average on the unit circle so that e.g. 359 and 1 average to 0
            # rather than 180.
            avg_cos = torch.sum(weights * torch.cos(bin_angles_rad).unsqueeze(0), dim=1)
            avg_sin = torch.sum(weights * torch.sin(bin_angles_rad).unsqueeze(0), dim=1)
            angles = torch.atan2(avg_sin, avg_cos) * 180.0 / torch.pi
        elif method == 'peak_fitting':
            angles = self._refine_peak(distributions)
        else:
            raise ValueError(f"Unknown extraction method: {method}")
        return self._wrap_degrees(angles)

    def get_distribution_uncertainty(self, distributions: torch.Tensor) -> torch.Tensor:
        """Calculate entropy-based uncertainty from distribution.

        Args:
            distributions: Probability distributions [B, num_bins]

        Returns:
            Shannon entropy of each distribution divided by the entropy of a
            uniform distribution over ``num_bins`` bins, shape [B]. With a
            single bin the uncertainty is defined as 0.
        """
        # The epsilon keeps log() finite for bins with zero probability.
        log_probs = torch.log(distributions + 1e-8)
        entropy = -torch.sum(distributions * log_probs, dim=1)
        if self.num_bins == 1:
            # log(1) == 0 would turn the normalisation into a division by zero.
            return torch.zeros_like(entropy)
        max_entropy = math.log(self.num_bins)
        return entropy / max_entropy
class CGDAngleEstimation(pl.LightningModule):
    """CGD model for 360 degree image orientation estimation (inference only).

    Wraps a timm backbone whose classification head has ``num_bins`` outputs.
    The softmax of those outputs is interpreted as a probability distribution
    over angle bins and decoded into an angle by
    ``CircularGaussianDistribution``.
    """

    def __init__(
        self,
        batch_size: int = 16,
        train_dir: str = "",
        model_name: str = "vit_tiny_patch16_224",
        learning_rate: float = 0.001,
        validation_split: float = 0.1,
        random_seed: int = 42,
        image_size: int = 224,
        num_bins: int = 360,
        sigma: float = 6.0,
        inference_method: str = 'argmax',
        loss_type: str = 'kl_divergence',
        test_dir=None,
        test_rotation_range=360.0,
        test_random_seed=42,
    ) -> None:
        """Build the backbone and the angle decoder.

        All arguments are recorded through ``save_hyperparameters()``. Within
        this class only the following are read again after construction:

        Args:
            model_name: timm architecture name used to create the backbone.
            image_size: Side length (pixels) images are resized to in
                ``predict_angle``.
            num_bins: Number of angle bins, i.e. backbone output size.
            sigma: Forwarded to ``CircularGaussianDistribution``.
            inference_method: Decoding method handed to
                ``CircularGaussianDistribution.distribution_to_angle``.

        The remaining arguments (``batch_size``, ``train_dir``,
        ``learning_rate``, ``validation_split``, ``random_seed``,
        ``loss_type``, ``test_dir``, ``test_rotation_range``,
        ``test_random_seed``) are accepted and stored but not used by any
        method of this class.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model_name = model_name
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.train_dir = train_dir
        self.validation_split = validation_split
        self.random_seed = random_seed
        self.image_size = image_size
        self.num_bins = num_bins
        self.sigma = sigma
        self.inference_method = inference_method
        self.loss_type = loss_type
        # NOTE(review): pretrained=True is requested even when the weights are
        # about to be replaced by a checkpoint — confirm this is intended.
        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_bins)
        self.cgd = CircularGaussianDistribution(num_bins=num_bins, sigma=sigma)

    @classmethod
    def try_load(cls, checkpoint_path=None, **kwargs):
        """Load model from checkpoint.

        Args:
            checkpoint_path: Path of the Lightning checkpoint to load.
            **kwargs: Forwarded unchanged to ``load_from_checkpoint``.

        Returns:
            The model restored from the checkpoint.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` is ``None`` or empty.
        """
        if not checkpoint_path:
            raise FileNotFoundError("Checkpoint file not found")
        logger.info(f"Loading model from checkpoint: {checkpoint_path}")
        model = cls.load_from_checkpoint(checkpoint_path, **kwargs)
        logger.info("Model loaded successfully from checkpoint")
        return model

    @classmethod
    def from_pretrained(cls, repo_id, model_name=None):
        """Load a pretrained model from HuggingFace Hub.

        Downloads ``config.json`` from the repository, resolves the requested
        entry under its ``"models"`` mapping, downloads that entry's
        checkpoint file and loads it in evaluation mode.

        Args:
            repo_id: HuggingFace repo ID (e.g. "maxwoe/image-rotation-angle-estimation")
            model_name: Display name or checkpoint filename from config.json.
                Defaults to the default model.

        Returns:
            The loaded model, switched to ``eval()`` mode.

        Raises:
            ValueError: If ``model_name`` matches neither a key of
                ``config["models"]`` nor any entry's ``"filename"``.
        """
        import json
        from huggingface_hub import hf_hub_download

        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        if model_name is None:
            model_name = config["default_model"]

        # Look up by display name first, then fall back to the filename.
        models = config["models"]
        model_info = models.get(model_name)
        if model_info is None:
            model_info = next(
                (info for info in models.values() if info["filename"] == model_name),
                None,
            )
        if model_info is None:
            available = [i["filename"] for i in models.values()]
            raise ValueError(f"Unknown model: {model_name}. Available: {available}")

        ckpt_path = hf_hub_download(repo_id=repo_id, filename=model_info["filename"])
        model = cls.try_load(checkpoint_path=ckpt_path, image_size=model_info["input_size"])
        model.eval()
        return model

    def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """Forward pass returning probability distribution over angles.

        Args:
            x: Input batch passed straight to the backbone.
            return_logits: When true, return the raw backbone outputs instead
                of their softmax.

        Returns:
            Softmax over dimension 1 of the backbone output, or the raw
            logits when ``return_logits`` is set.
        """
        logits = self.model(x)
        if return_logits:
            return logits
        return F.softmax(logits, dim=1)

    @staticmethod
    def _to_rgb_image(image):
        """Convert a path, numpy array or PIL image into an RGB PIL image."""
        import os

        if isinstance(image, (str, os.PathLike)):
            # Image.open is lazy; the context manager closes the file handle
            # once convert() has produced an independent in-memory copy.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert('RGB')
        if isinstance(image, Image.Image):
            return image.convert('RGB')
        raise TypeError(f"Expected PIL Image, numpy array, or file path, got {type(image)}")

    def _build_inference_transform(self):
        """Build the preprocessing pipeline used by ``predict_angle``."""
        try:
            data_config = timm.data.resolve_model_data_config(self.hparams.model_name)
            # No centre crop: the whole image is resized to the input size.
            data_config['crop_pct'] = 1.0
            data_config['input_size'] = (3, self.image_size, self.image_size)
            return timm.data.create_transform(**data_config, is_training=False)
        except Exception as exc:
            # Best-effort fallback, but make the switch visible in the logs.
            logger.warning(
                f"Could not build timm transform ({exc}); falling back to torchvision defaults"
            )
            return transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def predict_angle(self, image) -> float:
        """Detect the current orientation angle of an image.

        Switches the module to ``eval()`` mode, preprocesses the image, runs
        a single forward pass without gradients and decodes the resulting
        distribution with ``self.inference_method``.

        Args:
            image: PIL Image, numpy array, or file path (``str`` or
                ``os.PathLike``).
                For best results, pass PIL Image or numpy array directly.

        Returns:
            Predicted rotation angle in degrees [0, 360).

        Raises:
            TypeError: If ``image`` is none of the supported types.
        """
        self.eval()
        rgb_image = self._to_rgb_image(image)
        transform = self._build_inference_transform()
        # Place the input on the same device as the model weights so that
        # inference also works after the module was moved to an accelerator.
        image_tensor = transform(rgb_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            pred_distributions = self(image_tensor)
        return self.cgd.distribution_to_angle(
            pred_distributions, method=self.inference_method
        ).item()