Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import os | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import logging | |
| from utils.model_downloader import download_model_if_needed | |
# Module-level logging setup: install a root handler at INFO level with a
# timestamped "time - LEVEL - message" layout, then create this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class DepthEstimator:
    """
    Generalized Depth Estimation Model Wrapper for MiDaS and DPT models.

    Supports: MiDaS v2.1 Small, MiDaS v2.1 Large, DPT Hybrid, DPT Large.
    Both the network and its preprocessing pipeline are obtained through
    ``torch.hub.load("intel-isl/MiDaS", ...)``.
    """

    # model_key (as used by model_downloader) -> torch.hub entry point name.
    _HUB_MODEL_TYPES = {
        "midas_v21_small_256": "MiDaS_small",
        "midas_v21_384": "MiDaS",
        "dpt_hybrid_384": "DPT_Hybrid",
        "dpt_large_384": "DPT_Large",
        "dpt_swin2_large_384": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
        "dpt_beit_large_512": "DPT_Large",  # fallback to DPT_Large if not explicitly supported
    }
    # Entry point used when model_key is not in the table above.
    _FALLBACK_MODEL_TYPE = "MiDaS_small"
    # Hub entry points that must be fed through the DPT preprocessing pipeline.
    _DPT_MODEL_TYPES = frozenset({"DPT_Hybrid", "DPT_Large"})

    def __init__(
        self,
        model_key="midas_v21_small_256",
        weights_dir="models/depth/weights",
        device="cpu",
    ):
        """
        Initialize the Depth Estimation model.

        Args:
            model_key (str): Model identifier as defined in model_downloader.py.
            weights_dir (str): Directory to store/download model weights.
            device (str): Inference device ("cpu" or "cuda").
        """
        weights_path = os.path.join(weights_dir, f"{model_key}.pt")
        download_model_if_needed(model_key, weights_path)
        # NOTE(review): weights_path is not used again in this class; the
        # network below is obtained through torch.hub.load. Confirm whether the
        # download above is still needed.
        logger.info(f"Loading Depth model '{model_key}' from MiDaS hub")

        self.device = device
        self.model_type = self._resolve_model_type(model_key)
        self.midas = torch.hub.load("intel-isl/MiDaS", self.model_type).to(self.device).eval()
        self.transform = self._resolve_transform()

    def _resolve_model_type(self, model_key):
        """
        Map model_key to the MiDaS hub model type.

        Unknown keys fall back to "MiDaS_small"; a warning is logged so the
        substitution is visible instead of silent.

        Args:
            model_key (str): Model identifier.

        Returns:
            str: torch.hub entry point name.
        """
        model_type = self._HUB_MODEL_TYPES.get(model_key)
        if model_type is None:
            logger.warning(
                "Unknown depth model key '%s'; falling back to '%s'",
                model_key,
                self._FALLBACK_MODEL_TYPE,
            )
            return self._FALLBACK_MODEL_TYPE
        return model_type

    def _resolve_transform(self):
        """
        Load the MiDaS hub transforms and return the pipeline matching
        ``self.model_type``.
        """
        transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        return self._select_transform(transforms, self.model_type)

    @classmethod
    def _select_transform(cls, transforms, model_type):
        """
        Pick the preprocessing pipeline for a hub model type.

        Args:
            transforms: Object exposing ``small_transform``, ``dpt_transform``
                and ``default_transform`` (the MiDaS hub "transforms" entry).
            model_type (str): torch.hub entry point name.

        Returns:
            The transform callable for that model type.
        """
        if model_type == "MiDaS_small":
            return transforms.small_transform
        if model_type in cls._DPT_MODEL_TYPES:
            # The MiDaS hub usage pairs DPT models with dpt_transform rather
            # than the default pipeline.
            return transforms.dpt_transform
        return transforms.default_transform

    @staticmethod
    def _to_rgb_array(image):
        """
        Convert a PIL image to an RGB numpy array.

        The hub transforms start with an arithmetic division on their input,
        which a PIL image does not support; forcing RGB also gives grayscale,
        palette and RGBA inputs three channels.
        """
        return np.asarray(image.convert("RGB"))

    def predict(self, image: "Image.Image"):
        """
        Generates a depth map for the given image.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            np.ndarray: Depth map as a 2D numpy array, resized to the input
            image's (height, width).
        """
        logger.info("Running depth estimation")
        input_tensor = self.transform(self._to_rgb_array(image)).to(self.device)

        with torch.no_grad():
            prediction = self.midas(input_tensor)
            # PIL reports size as (width, height); interpolate wants (height, width).
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            )

        # Drop only the channel and batch dims so an image that is one pixel
        # wide or tall still yields a 2D map.
        depth_map = prediction.squeeze(1).squeeze(0).cpu().numpy()
        logger.info("Depth estimation completed successfully")
        return depth_map
