import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time
import timm
import cv2
from datetime import datetime

# Log the timm version to confirm the dependency is available
print(f"Timm version: {timm.__version__}")

# Run the script to download the pretrained model weights; check=True aborts loudly on failure
subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Set the device to GPU if available, else CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the depth prediction model and its preprocessing transforms
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)  # Move the model to the selected device
model.eval()  # Set the model to evaluation mode

def resize_image(image_path, max_size=1024):
    """
    Resize the input image to ensure its largest dimension does not exceed max_size.
    Maintains the aspect ratio and saves the resized image as a temporary PNG file.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.
    """
    with Image.open(image_path) as img:
        # Compute the resize ratio, capped at 1.0 so smaller images are never upscaled
        ratio = min(max_size / max(img.size), 1.0)
        new_size = tuple(int(x * ratio) for x in img.size)
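        # Worked example: a 4032x3024 photo gives ratio = 1024/4032 ~= 0.254 and
        # new_size = (1024, 768); an 800x600 image keeps ratio = 1.0 and is left unchanged.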
        
        # Resize the image using LANCZOS filter for high-quality downsampling
        img = img.resize(new_size, Image.LANCZOS)
        
        # Save the resized image to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name

def generate_3d_model(depth, image_path, focallength_px):
    """
    Generate a textured 3D mesh from the depth map and the original image.

    Args:
        depth (np.ndarray): 2D array representing depth in meters.
        image_path (str): Path to the resized RGB image.
        focallength_px (float): Focal length in pixels.

    Returns:
        tuple: Paths to the exported 3D model files for viewing and downloading.
    """
    # Load the image and force 3 channels so RGBA/grayscale inputs don't break vertex coloring
    image = np.array(Image.open(image_path).convert("RGB"))
    height, width = depth.shape

    # Compute camera intrinsic parameters
    fx = fy = focallength_px  # Assuming square pixels and fx = fy
    cx, cy = width / 2, height / 2  # Principal point at the image center

    # Create a grid of (u, v) pixel coordinates
    u = np.arange(0, width)
    v = np.arange(0, height)
    uu, vv = np.meshgrid(u, v)

    # Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy
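    # (Derivation: the pinhole model projects a 3D point (X, Y, Z) to pixel
    #  coordinates u = fx * X / Z + cx and v = fy * Y / Z + cy; solving for
    #  X and Y given the predicted depth Z yields the back-projection above.)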

    # Stack the coordinates to form vertices (X, Y, Z)
    vertices = np.vstack((X, Y, Z)).T

    # Normalize RGB colors to [0, 1] for vertex coloring
    colors = image.reshape(-1, 3) / 255.0

    # Generate faces by connecting adjacent vertices to form triangles
    faces = []
    for i in range(height - 1):
        for j in range(width - 1):
            idx = i * width + j
            # Triangle 1
            faces.append([idx, idx + width, idx + 1])
            # Triangle 2
            faces.append([idx + 1, idx + width, idx + width + 1])
    faces = np.array(faces)
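    # Performance note: the nested Python loop above is O(height * width); a
    # vectorized construction (a sketch, assuming the same row-major vertex
    # layout; face ordering differs but the resulting mesh is identical) would be:
    #   idx = (np.arange(height - 1)[:, None] * width + np.arange(width - 1)).ravel()
    #   tri1 = np.stack([idx, idx + width, idx + 1], axis=1)
    #   tri2 = np.stack([idx + 1, idx + width, idx + width + 1], axis=1)
    #   faces = np.concatenate([tri1, tri2])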

    # Create the mesh using Trimesh with vertex colors
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

    # Export two copies of the mesh: one for the 3D viewer, one for the download link
    timestamp = int(time.time())
    view_model_path = f'view_model_{timestamp}.obj'
    download_model_path = f'download_model_{timestamp}.obj'
    mesh.export(view_model_path)
    mesh.export(download_model_path)
    return view_model_path, download_model_path

@spaces.GPU(duration=20)
def predict_depth(input_image):
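    """
    Predict a metric depth map for an uploaded image and build all demo outputs.

    Args:
        input_image (str): Filepath of the uploaded image (Gradio "filepath" type).

    Returns:
        tuple: (depth-map PNG path, focal-length or error message, raw-depth CSV path,
            3D model path for viewing, 3D model path for download).
    """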
    temp_file = None
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")
        
        # Resize the input image to a manageable size
        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")
        
        # Preprocess the image for depth prediction
        result = depth_pro.load_rgb(temp_file)
        
        # Sanity-check the tuple: load_rgb returns the image first and the focal
        # length (in pixels, typically derived from EXIF metadata when available) last
        if len(result) < 2:
            raise ValueError(f"Unexpected result from load_rgb: {result}")
        
        image = result[0]
        f_px = result[-1]
        
        print(f"Extracted focal length: {f_px}")
        
        image = transform(image)  # Apply preprocessing transforms
        image = image.to(device)  # Move the image tensor to the selected device

        # Run the depth prediction model
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth map in meters
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth from torch tensor to NumPy array if necessary
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Ensure the depth map is a 2D array
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Print debug information
        print(f"Original depth shape: {depth.shape}")
        print(f"Original image shape: {image.shape}")

        # Resize the depth map to match the image tensor's spatial dimensions
        # (negative indexing works for both (C, H, W) and (N, C, H, W) tensors)
        image_height, image_width = image.shape[-2], image.shape[-1]
        depth = cv2.resize(depth, (image_width, image_height), interpolation=cv2.INTER_LINEAR)
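        # Design note: bilinear interpolation smooths depth edges at object
        # boundaries; if crisp discontinuities matter more, nearest-neighbor
        # resizing is a drop-in alternative:
        #   depth = cv2.resize(depth, (image_width, image_height), interpolation=cv2.INTER_NEAREST)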

        print(f"Resized depth shape: {depth.shape}")
        print(f"Final image shape: {image.shape}")

        # No downsampling is applied; the depth map stays at full resolution

        # Depth is already metric (meters), so no normalization is needed;
        # record the range for the plot title
        depth_min = np.min(depth)
        depth_max = np.max(depth)

        # Render the depth map with a color bar using matplotlib
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
        plt.axis('off')  # Hide the axes for a cleaner image

        # Save the depth map visualization to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        # Save the raw depth data to a CSV file for download
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')
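        # Note: savetxt writes a large plain-text file; NumPy's binary format
        # (loadable with np.load) is a compact alternative:
        #   np.save("raw_depth_map.npy", depth)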

        # Generate the 3D model from the depth map and resized image
        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path
    except Exception as e:
        # Return error messages in case of failures
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print the full error message to the console
        return None, error_message, None, None, None
    finally:
        # Clean up by removing the temporary resized image file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

def get_last_commit_timestamp():
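    """Return the last commit time as "YYYY-MM-DD HH:MM:SS", or "Unknown" if git is unavailable."""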
    try:
        # --date=iso-strict yields e.g. "2024-06-01T12:34:56+02:00", which
        # datetime.fromisoformat can parse on all supported Python versions
        timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso-strict']).decode('utf-8').strip()
        return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except Exception:
        return "Unknown"
    
# Create the Gradio interface with appropriate input and output components. 
last_updated = get_last_commit_timestamp()

iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Depth Map"),
        gr.Textbox(label="Focal Length or Error Message"),
        gr.File(label="Download Raw Depth Map (CSV)"),
        gr.Model3D(label="View 3D Model"),
        gr.File(label="Download 3D Model (OBJ)")
    ],
    title="DepthPro Demo with 3D Visualization",
    description=(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
        "**Instructions:**\n"
        "1. Upload an image.\n"
        "2. The app will predict the depth map, display it, and provide the focal length.\n"
        "3. Download the raw depth data as a CSV file.\n"
        "4. View the generated 3D model textured with the original image.\n"
        "5. Download the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    ),
)

# Launch the Gradio interface; share=True creates a public link when running
# locally (Hugging Face Spaces ignores the flag and serves the app directly)
iface.launch(share=True)