import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamConfig, SamProcessor, SamModel

# Load the model and processor
print('Status Update: Loading SAM Model ...')
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Create an instance of the model architecture from the loaded configuration.
# Note: this starts with randomly initialized weights; the fine-tuned weights are loaded next.
sidewalk_model = SamModel(config=model_config)

# Load the fine-tuned weights from the saved checkpoint
# (the checkpoint is expected to store the state dict under the "model" key).
print('Status Update: Loading fine-tuned SAM weights ...')
checkpoint = torch.load("checkpoints/checkpoint_at_epoch_1.pt", map_location=torch.device('cpu'))
sidewalk_model.load_state_dict(checkpoint["model"])

# Set the device (uncomment the line below to prefer Apple-silicon MPS when available)
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"
sidewalk_model.to(device)
print('Status Update: FindMySidewalk ready for inference ...')

# Generate a (randomly perturbed) bounding-box prompt for SAM
def get_bounding_box(W=256, H=256, x_min=0, y_min=0, x_max=256, y_max=256):
    """Return [x_min, y_min, x_max, y_max], expanded outward by a random
    0-19 px on each side and clipped to the W x H image bounds."""
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))

    bbox = [x_min, y_min, x_max, y_max]

    return bbox
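
# For illustration (values are made up): get_bounding_box(512, 512, 100, 120, 300, 340)
# might return [93, 108, 312, 352], depending on the random jitter.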

def segment_sidewalk(image):
    """Segment sidewalks in an image tile and return (original image, binary mask)."""
    # Gradio passes the upload as a NumPy array; convert it to an RGB PIL image
    test_image = Image.fromarray(image).convert("RGB")

    # Keep a copy of original image for display
    original_image = test_image.copy()

    # An earlier, commented-out version prompted the model with a uniform
    # 7x7 grid of input points over the 256x256 tile (np.linspace / np.meshgrid
    # fed to the processor as input_points); the bounding-box prompt below
    # is used instead.

    # Build a bounding-box prompt covering the entire image
    prompt = get_bounding_box(test_image.size[0], test_image.size[1], 0, 0, test_image.size[0], test_image.size[1])
    # Prepare the image and box prompt for the model
    inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
    # Cast all tensors to float32 (the MPS backend doesn't support float64) and move them to the device
    inputs = {k: v.to(torch.float32).to(device) for k, v in inputs.items()}
    sidewalk_model.eval()

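    # multimask_output=False asks SAM for a single mask per prompt
    # rather than its usual three candidate masks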
    with torch.no_grad():
        outputs = sidewalk_model(**inputs, multimask_output=False)

    # apply sigmoid and convert soft mask to hard mask
    test_image_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    test_image_prob = test_image_prob.cpu().numpy().squeeze()
    pred_mask_np = (test_image_prob > 0.85).astype(np.uint8)
    segmented_image = Image.fromarray(pred_mask_np * 255)  # Convert mask to an image
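    # Note: SAM's pred_masks come back at the mask decoder's fixed 256x256
    # resolution, not the original tile size. For non-256x256 inputs, a
    # possible (untested) fix is to upscale before display:
    # segmented_image = segmented_image.resize(original_image.size, Image.NEAREST)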

    return original_image, segmented_image


# Setup Gradio interface
demo = gr.Interface(
    fn=segment_sidewalk,
    inputs=gr.Image(type="numpy"),  # segment_sidewalk expects a NumPy array
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Segmented Mask")
    ],
    title="Sidewalk Segmentation using SAM",
    description="Upload a satellite image tile to segment sidewalks."
)

# Run the interface
demo.launch()
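
# To expose a temporary public URL instead, Gradio also supports:
# demo.launch(share=True)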