import os
import warnings
warnings.filterwarnings('ignore')

import subprocess, sys

os.system("pip install -q gradio")
os.system("pip install -q diffusers")
os.system("pip install -q segment_anything")
os.system("pip install -q accelerate")
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
print(f'pip install GroundingDINO = {result}')
sys.path.insert(0, './GroundingDINO')
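# The editable install above plus the sys.path entry make the local GroundingDINO
# clone importable without restarting the interpreter.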


'''Importing Libraries'''

import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry
from segment_anything import SamPredictor
from diffusers import StableDiffusionInpaintPipeline, AutoPipelineForInpainting

from scipy.ndimage import binary_dilation

import cv2
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import draw_segmentation_masks

torch.set_default_dtype(torch.float32)


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    '''
        Downloads the Grounding DINO config and checkpoint from the Hugging Face Hub
        and returns the model in eval mode.
    '''
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file) 
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.eval()
    return model  


def transform_image(image) -> torch.Tensor:
    '''
        Converts a PIL image into the tensor format expected by Grounding DINO.
        The resize and ImageNet normalization used in the reference Grounding DINO
        transform are left commented out below.
    '''
    transform = T.Compose([
#         T.RandomResize([800], max_size=1333),
        T.ToTensor(),
#         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image_transformed, _ = transform(image, None)
    return image_transformed


class CFG:
    '''
        Configuration constants: SAM checkpoint URLs, inpainting model IDs,
        the compute device, and the Grounding DINO checkpoint on the Hugging Face Hub.
    '''
    # sam_type = "vit_h"
    SAM_MODELS = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    }
    INPAINTING_MODELS = {
        "Stable Diffusion" : "runwayml/stable-diffusion-inpainting",
        "Stable Diffusion 2" : "stabilityai/stable-diffusion-2-inpainting",
        "Stable Diffusion XL" : "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    }
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ckpt_repo_id = "ShilongLiu/GroundingDINO"
    ckpt_filename = "groundingdino_swinb_cogcoor.pth"
    ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
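
# Note: the SwinB Grounding DINO checkpoint is used here; the same Hub repo also hosts
# a lighter SwinT variant that could be swapped in (check the exact filenames in the
# repo) to reduce memory use.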



'''Build models'''
def build_sam(sam_type):
    checkpoint_url = CFG.SAM_MODELS[sam_type]
    sam = sam_model_registry[sam_type]()
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
    sam.load_state_dict(state_dict, strict=True)
    sam.to(device = CFG.device)
    sam = SamPredictor(sam)
    print('SAM is built!')
    return sam
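
# SAM backbone trade-off: "vit_h" (the UI default) gives the best masks but is the
# heaviest; "vit_b" is the lightest of the three if GPU memory is tight.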


def build_groundingdino():
    ckpt_repo_id = CFG.ckpt_repo_id
    ckpt_filename = CFG.ckpt_filename
    ckpt_config_filename = CFG.ckpt_config_filename
    groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
    print('Grounding DINO is built!')
    return groundingdino
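
# hf_hub_download caches the config and weights locally, so rebuilding the model
# does not re-download them.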


'''Predictions'''
def predict_dino(image_pil, text_prompt, box_threshold, text_threshold, model_groundingdino):
    image_trans = transform_image(image_pil)
    boxes, logits, phrases = predict(model = model_groundingdino,
                                     image = image_trans,
                                     caption = text_prompt,
                                     box_threshold = box_threshold,
                                     text_threshold = text_threshold,
                                     device = CFG.device)
    W, H = image_pil.size
    boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])  # normalized (cx, cy, w, h) -> pixel (x1, y1, x2, y2)
    print('DINO prediction done!')
    return boxes, logits, phrases
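
# predict_dino returns: boxes as (num_detections, 4) pixel xyxy coordinates, logits
# as per-box confidence scores, and phrases as the matched spans of the caption.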


def predict_sam(image_pil, boxes, model_sam):
    image_array = np.asarray(image_pil)
    model_sam.set_image(image_array)
    transformed_boxes = model_sam.transform.apply_boxes_torch(boxes, image_array.shape[:2])
    masks, _, _ = model_sam.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes.to(model_sam.device),
        multimask_output=False,
    )
    print('SAM prediction done!')
    return masks.cpu()
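
# With multimask_output=False, SAM returns one boolean mask per box with shape
# (num_boxes, 1, H, W); the singleton channel is squeezed out by the caller.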


def mask_predict(image_pil, text_prompt, box_threshold, text_threshold, models):
    boxes, logits, phrases = predict_dino(image_pil, text_prompt, box_threshold, text_threshold, models[0])
    masks = torch.tensor([])
    if len(boxes) > 0:
        masks = predict_sam(image_pil, boxes, models[1])
        masks = masks.squeeze(1)
    return masks, boxes, phrases, logits
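
# Example (hypothetical values, mirroring the commented defaults further below):
# models = build_models('vit_h')
# masks, boxes, phrases, logits = mask_predict(load_image('stool.jpeg'), 'wooden stool', 0.23, 0.25, models)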


'''Utils'''

def load_image(image_path):
    return Image.open(image_path).convert("RGB")


def draw_image(image_pil, masks, boxes, alpha=0.4):
    image = np.asarray(image_pil)
    image = torch.from_numpy(image).permute(2, 0, 1)
    if len(masks) > 0:
        image = draw_segmentation_masks(image, masks=masks, colors=['red'] * len(masks), alpha=alpha)
    return image.numpy().transpose(1, 2, 0)
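
# Note: boxes are accepted for API symmetry but not drawn; draw_bounding_boxes
# (imported above) could be used here to overlay them as well.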

# torch.save(masks, 'masks.pt')


'''Visualise segmented results'''

def visualize_results(img1, img2, task, prompt=''):
    '''
        Plots the original image next to the processed result (debugging helper;
        the calls below are commented out).
    '''
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))

    axes[0].imshow(img1)
    axes[0].set_title('Original Image')

    axes[1].imshow(img2)
    axes[1].set_title(f'{prompt} : {task}' if prompt else task)

    for ax in axes:
        ax.axis('off')

# visualize_results(image_pil, output, 'segmented')

# x_units = 200
# y_units = -100
# text_prompt = 'wooden stool'
# image_path = '/kaggle/input/avataar/stool.jpeg'
# output_image_path = '/kaggle/working'

def build_models(sam_type):
    model_sam = build_sam(sam_type)
    model_groundingdino = build_groundingdino()
    models = [model_groundingdino, model_sam]
    return models
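
# Note: build_models runs on every Gradio submit; checkpoints are cached on disk,
# but the weights are still reloaded each time, so caching the built models per
# sam_type would speed up repeated runs.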
    

def main_fun(image_pil, x_units, y_units, text_prompt, box_threshold, text_threshold, inpaint_text_prompt, num_inference_steps, sam_type, inpainting_model):
#     x_units = 200
#     y_units = -100
#     text_prompt = 'wooden stool'
    
#     image_pil = load_image(image_path)
    models = build_models(sam_type)

    masks, boxes, phrases, logits = mask_predict(image_pil, text_prompt, box_threshold, text_threshold, models)
    segmented_image = draw_image(image_pil, masks, boxes, alpha=0.4)

    # Combine all segmentation masks into one
    if len(masks) == 0:
        raise ValueError(f'No "{text_prompt}" detected; try lowering the box threshold.')
    combined_mask = torch.sum(masks, dim=0)
    combined_mask = combined_mask.numpy() != 0

    '''Get masked object and background as two separate images'''
    image_np = np.asarray(image_pil)
    mask = np.expand_dims(combined_mask, axis=-1)
    masked_object = image_np * mask
    background = image_np * ~mask


    '''Shift the segmented object right by x_units and up by y_units'''
    M = np.float32([[1, 0, x_units], [0, 1, -y_units]])  # affine translation matrix
    shifted_image = cv2.warpAffine(masked_object, M, (masked_object.shape[1], masked_object.shape[0]), borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0))

    '''Load stable diffuser model at checkpoint finetuned for inpainting task'''
    inpainting_model_path = CFG.INPAINTING_MODELS[inpainting_model]

    if inpainting_model == 'Stable Diffusion XL':
        pipe = AutoPipelineForInpainting.from_pretrained(
            inpainting_model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
    else:
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            inpainting_model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)

    pipe.to(CFG.device)
    print('Stable Diffusion inpainting model loaded!')

    # With dilation: grow the mask slightly so inpainting also covers a thin border
    # around the object, which reduces halo artifacts at its original location
    structuring_element = np.ones((15, 15, 1), dtype=bool)
    extrapolated_mask = binary_dilation(mask, structure=structuring_element)
    mask_as_uint8 = extrapolated_mask.astype(np.uint8) * 255
    pil_mask = Image.fromarray(mask_as_uint8.squeeze(), mode='L').resize((1024, 1024))

    # # Without Dilation
    # pil_background = Image.fromarray(background)
    # mask_as_uint8 = mask.astype(np.uint8) * 255
    # pil_mask = Image.fromarray(mask_as_uint8.squeeze(), mode='L')
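
    # The diffusers inpainting pipelines resize the image and mask internally to the
    # model's working resolution, so passing the full-size image_pil with the
    # 1024x1024 mask above is tolerated, though matching sizes would be cleaner.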

    print('Image inpainting in progress...')
    '''Inpaint the masked region of the original image'''
    prompt = inpaint_text_prompt
    inpainted_image = pipe(prompt=prompt, image=image_pil, mask_image=pil_mask, num_inference_steps=num_inference_steps).images[0]
    print('Image inpainted!')

    '''Composite the shifted object over the inpainted background image'''
    pil_shifted_image = Image.fromarray(shifted_image).resize(inpainted_image.size)
    np_shifted_image = np.array(pil_shifted_image)
    masked_shifted_image = np.where(np_shifted_image[:, :, 0] != 0, True, False)
    masked_shifted_image = np.expand_dims(masked_shifted_image, axis=-1)
    inpainted_shifted = np.array(inpainted_image) * ~masked_shifted_image
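    # Where the shifted object lands, inpainted_shifted is zeroed out; adding the
    # resized shifted object back (next two lines) fills exactly those pixels.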

    shifted_image = cv2.resize(shifted_image, inpainted_image.size)
    output = inpainted_shifted + shifted_image
    output = Image.fromarray(output)
#     visualize_results(image_pil, output, 'shifted')
    segmented_image = Image.fromarray(segmented_image)
    return segmented_image.resize(image_pil.size), output.resize(image_pil.size)

import gradio as gr

image_blocks = gr.Blocks()
with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(sources=['upload'], type="pil", label="Upload")
            # with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
            text_prompt = gr.Textbox(placeholder='Object to detect and move (e.g. wooden stool)', label="Object class", show_label=True)
            x_units = gr.Slider(minimum=0, maximum=300, step=10, value=100, label="x_units")
            y_units = gr.Slider(minimum=0, maximum=300, step=10, value=0, label="y_units")
            sam_type = gr.Dropdown(
                    ["vit_h", "vit_l", "vit_b"], label="ViT base model for SAM", value="vit_h"
                )
            inpainting_model = gr.Dropdown(
                    ["Stable Diffusion", "Stable Diffusion 2", "Stable Diffusion XL"], label="Model for inpainting", value="Stable Diffusion 2"
                )
            with gr.Accordion("Advanced options", open=False) as advanced_options:
                box_threshold = gr.Slider(
                    label="Box Threshold", minimum=0.0, maximum=1.0, value=0.23, step=0.01
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=20, maximum=100, value=20, step=10
                )
                text_threshold = gr.Slider(
                    label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.01
                )
                inpaint_text_prompt = gr.Textbox(placeholder='Inpainting prompt (default: fill as per background)', value="fill as per background", show_label=False)

            # text_prompt = gr.Textbox(lines=1, label="Prompt")
            btn = gr.Button(value="Submit")
        with gr.Column():
            image_out_seg = gr.Image(label="Segmented object", height=400, width=400)
            image_out_shift = gr.Image(label="Shifted object", height=400, width=400)
            
    btn.click(fn=main_fun, inputs=[image, x_units, y_units, text_prompt, box_threshold, text_threshold, inpaint_text_prompt, num_inference_steps, sam_type, inpainting_model], outputs=[image_out_seg, image_out_shift])

image_blocks.launch()