| | import os |
| | import torch |
| | import base64 |
| | import io |
| | import requests |
| | import matplotlib.pyplot as plt |
| | from PIL import Image |
| | from transformers import AutoImageProcessor, AutoModelForDepthEstimation |
| | import numpy as np |
| |
|
| |
|
| |
|
class EndpointHandler:
    """Inference handler for a depth-estimation endpoint.

    Loads an ``AutoModelForDepthEstimation`` checkpoint (plus its image
    processor) and serves per-image depth maps for inputs supplied by URL,
    base64 payload, or an in-process PIL image.
    """

    def __init__(self, path=""):
        """Load the processor and model.

        Args:
            path: Checkpoint path/repo id. Falls back to the MODEL_PATH
                environment variable when empty.
        """
        self.model_path = path or os.environ.get("MODEL_PATH", "")
        self.image_processor = AutoImageProcessor.from_pretrained(self.model_path)
        self.model = AutoModelForDepthEstimation.from_pretrained(self.model_path)

        # Inference-only: move to GPU when available and switch off
        # dropout/batch-norm updates via eval().
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

    def _load_image(self, data):
        """Resolve the request payload to an RGB PIL image.

        Accepts 'url' (remote fetch), 'file' (base64-encoded bytes), or
        'image' (an already-constructed PIL Image object).

        Raises:
            ValueError: when none of the accepted keys is present.
        """
        if "url" in data:
            # timeout so an unresponsive host cannot hang the endpoint;
            # response.content is the fully-decoded body — response.raw can
            # hand PIL a still-compressed stream when the server gzips.
            response = requests.get(data["url"], timeout=30)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        if "file" in data:
            image_bytes = base64.b64decode(data["file"])
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if "image" in data:
            return data["image"]
        raise ValueError("No valid image input found. Please provide either 'url', 'file' (base64 encoded image), or 'image' (PIL Image object).")

    def _predict_depth(self, image):
        """Run the model and return an (H, W) numpy depth map at *image* resolution."""
        inputs = self.image_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # Model output is (batch, h', w'); upsample to the original image
        # resolution. PIL's size is (W, H); interpolate expects (H, W).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
        return prediction.cpu().numpy()

    @staticmethod
    def _render_visualization(normalized_depth):
        """Render the normalized depth map to a base64-encoded PNG string."""
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(normalized_depth, cmap='plasma')
        plt.axis('off')

        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        # Close this specific figure (not just "the current" one) so
        # concurrent requests cannot leak matplotlib figures.
        plt.close(fig)
        buf.seek(0)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def __call__(self, data):
        """
        Args:
            data: Input data in the format of a dictionary with either:
                - 'url': URL of the image
                - 'file': Base64 encoded image
                - 'image': PIL Image object
                - 'visualization': Boolean flag to return visualization-friendly format (default: False)
                - 'x': Int pixel position on axis x
                - 'y': Int pixel position on axis y
        Returns:
            Dictionary containing the depth map and metadata
        """
        image = self._load_image(data)
        depth_map = self._predict_depth(image)

        depth_min = float(depth_map.min())
        depth_max = float(depth_map.max())
        # Guard the constant-map case (max == min) — otherwise 0/0 -> NaN.
        depth_range = depth_max - depth_min
        normalized_depth = (depth_map - depth_min) / (depth_range or 1.0)

        if data.get("visualization", False):
            return {
                "visualization": self._render_visualization(normalized_depth),
                "min_depth": depth_min,
                "max_depth": depth_max,
                "format": "base64_png",
            }

        # Clamp the queried pixel into the map so out-of-range coordinates
        # do not raise IndexError; numpy indexing is [row=y, col=x].
        height, width = depth_map.shape
        x = min(max(int(data.get('x', 0)), 0), width - 1)
        y = min(max(int(data.get('y', 0)), 0), height - 1)
        # Cast to a plain Python float so the response is JSON-serializable
        # (numpy scalars are not).
        value = float(depth_map[y, x])
        # 'deph' is a historical typo kept for backward compatibility;
        # new callers should read the correctly spelled 'depth' key.
        return {"deph": value, "depth": value}
| |
|