import cv2
import numpy as np
import torch
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

# Initialize the depth estimator once at module import so it is not
# re-loaded on every call. Prefer the project-wide config.DEVICE and
# config.DTYPE when available; otherwise fall back to sensible defaults.
try:
    from config import DEVICE, DTYPE
except ImportError:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print(f"Loading Depth Estimator on {DEVICE}...")
# Note: the DPT model is kept in its default float32 weights. Casting it to
# DTYPE (e.g. bfloat16) would also require casting the inputs, and bfloat16
# tensors cannot be converted to NumPy arrays directly in apply_depth().
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
depth_estimator.to(DEVICE)
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
print("Depth Estimator loaded.")


def apply_canny(image: Image.Image) -> Image.Image:
    """Apply Canny edge detection to a PIL Image.

    Returns a 3-channel edge map suitable as ControlNet conditioning.
    """
    # Normalize to RGB first so RGBA or palette images do not break
    # cv2.Canny, which expects a single-channel 8-bit input.
    image_np = np.array(image.convert("RGB"))
    image_gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # Canny with the common 100/200 hysteresis thresholds; adjust to taste.
    image_edges = cv2.Canny(image_gray, 100, 200)

    # Stack the single-channel edge map into 3 channels for ControlNet.
    image_edges = np.stack([image_edges] * 3, axis=-1)
    return Image.fromarray(image_edges)


def apply_depth(image: Image.Image) -> Image.Image:
    """Estimate depth from a PIL Image and return a depth map image.

    The returned depth map matches the input's original size.
    """
    original_size = image.size  # (width, height)

    # Downscale very large images before inference to keep depth estimation
    # fast, preserving the aspect ratio.
    max_dim = max(original_size)
    if max_dim > 768:
        scale_factor = 768 / max_dim
        image = image.resize(
            (int(original_size[0] * scale_factor), int(original_size[1] * scale_factor)),
            Image.BICUBIC,
        )

    inputs = feature_extractor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = depth_estimator(**inputs)
        predicted_depth = outputs.predicted_depth

    # Upsample the prediction back to the original size. PIL sizes are
    # (width, height) while interpolate expects (height, width).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=original_size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    output = prediction.squeeze().cpu().numpy()

    # Normalize to 0-255 and convert to uint8 for a viewable depth map.
    formatted_output = np.interp(output, (output.min(), output.max()), (0, 255)).astype(np.uint8)
    return Image.fromarray(formatted_output)
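

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of feeding these control images into a
# diffusers ControlNet pipeline. This is illustrative, not part of this
# module's API: the checkpoint names ("lllyasviel/sd-controlnet-canny",
# "runwayml/stable-diffusion-v1-5"), the input path "input.jpg", and the
# float16 dtype are assumptions; swap in the ControlNet/base-model pair and
# image you actually use (apply_depth pairs with a depth-conditioned
# ControlNet such as "lllyasviel/sd-controlnet-depth" in the same way).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    # Build a Canny conditioning image from a local photo.
    source = Image.open("input.jpg")
    canny_image = apply_canny(source)

    # Load a Canny-conditioned ControlNet and attach it to an SD 1.5 pipeline.
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch.float16,
    ).to(DEVICE)

    # The edge map steers the layout of the generated image.
    result = pipe(
        "a futuristic city street at dusk",
        image=canny_image,
        num_inference_steps=30,
    ).images[0]
    result.save("output.png")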