from io import BytesIO
import string
import gradio as gr
import requests
from caption_anything import CaptionAnything
import torch
import json
import sys
import argparse
from caption_anything import parse_augment
import os

# download sam checkpoint if not downloaded
def download_checkpoint(url, folder, filename):
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    if not os.path.exists(filepath):
        response = requests.get(url, stream=True)
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

    return filepath
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder = "segmenter"
filename = "sam_vit_h_4b8939.pth"

title = """<h1 align="center">Caption-Anything</h1>"""
description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
<br> <strong>Code</strong>: GitHub repo: <a href='https://github.com/ttengwang/Caption-Anything' target='_blank'></a>
"""

examples = [
    ["test_img/img2.jpg", "[[1000, 700, 1]]"]
]

args = parse_augment()

def get_prompt(chat_input, click_state):    
    points = click_state[0]
    labels = click_state[1]
    inputs = json.loads(chat_input)
    for input in inputs:
        points.append(input[:2])
        labels.append(input[2])
    
    prompt = {
        "prompt_type":["click"],
        "input_point":points,
        "input_label":labels,
        "multimask_output":"True",
    }
    return prompt
    
def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state):
    controls = {'length': length,
             'sentiment': sentiment,
             'factuality': factuality,
             'language': language}
    prompt = get_prompt(chat_input, click_state)
    print('prompt: ', prompt, 'controls: ', controls)
    out = model.inference(image_input, prompt, controls)
    state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
    for k, v in out['generated_captions'].items():
        state = state + [(f'{k}: {v}', None)]
    click_state[2].append(out['generated_captions']['raw_caption'])
    image_output_mask = out['mask_save_path']
    image_output_crop = out['crop_save_path']
    return state, state, click_state, image_output_mask, image_output_crop


def upload_callback(image_input, state):
    state = state + [('Image size: ' + str(image_input.size), None)]
    return state

# get coordinate in format [[x,y,positive/negative]]
def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
        print("point_prompt: ", point_prompt)
        if point_prompt == 'Positive Point':
            coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
        else:
            coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
        return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
    
def chat_with_points(chat_input, click_state, state):
    points, labels, captions = click_state
    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
    # "The image is of width {width} and height {height}." 
    point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
    prev_visual_context = ""
    pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
    prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
    chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
    response = model.text_refiner.llm(chat_prompt)
    state = state + [(chat_input, response)]
    return state, state

def init_openai_api_key(api_key):
    # os.environ['OPENAI_API_KEY'] = api_key
    global model
    model = CaptionAnything(args, api_key)

css='''
#image_upload{min-height:200px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px}
'''

with gr.Blocks(css=css) as iface:
    state = gr.State([])
    click_state = gr.State([[],[],[]])
    caption_state = gr.State([[]])
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Column():
        openai_api_key = gr.Textbox(
            placeholder="Input your openAI API key and press Enter",
            show_label=False,
            lines=1,
            type="password",
        )
        openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
        
        with gr.Row():
            with gr.Column(scale=0.7):
                image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260,scale=1.0)

                with gr.Row(scale=0.7):
                    point_prompt = gr.Radio(
                        choices=["Positive Point",  "Negative Point"],
                        value="Positive Point",
                        label="Points",
                        interactive=True,
                    )
                
                # with gr.Row():
                language = gr.Radio(
                    choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese","Cantonese"],
                    value="English",
                    label="Language",
                    interactive=True,
                )
                sentiment = gr.Radio(
                    choices=["Positive", "Natural", "Negative"],
                    value="Natural",
                    label="Sentiment",
                    interactive=True,
                )
                factuality = gr.Radio(
                    choices=["Factual", "Imagination"],
                    value="Factual",
                    label="Factuality",
                    interactive=True,
                )
                length = gr.Slider(
                    minimum=5,
                    maximum=100,
                    value=10,
                    step=1,
                    interactive=True,
                    label="Length",
                )

            with gr.Column(scale=1.5):
                with gr.Row():
                    image_output_mask= gr.Image(type="pil", interactive=False, label="Mask").style(height=260,scale=1.0)
                    image_output_crop= gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260,scale=1.0)
                chatbot = gr.Chatbot(label="Chat Output",).style(height=450,scale=0.5)
        
        with gr.Row():
            with gr.Column(scale=0.7):
                prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like : [[100, 200, 1], [200,300,0]])")
                prompt_input.submit(
                    inference_seg_cap,
                    [
                        image_input,
                        prompt_input,
                        language,
                        sentiment,
                        factuality,
                        length,
                        state,
                        click_state
                    ],
                    [chatbot, state, click_state, image_output_mask, image_output_crop],
                    show_progress=False
                )
                
                image_input.upload(
                    upload_callback,
                    [image_input, state],
                    [chatbot]
                    )

                with gr.Row():
                    clear_button = gr.Button(value="Clear Click", interactive=True)
                    clear_button.click(
                        lambda: ("", [[], [], []], None, None),
                        [],
                        [prompt_input, click_state, image_output_mask, image_output_crop],
                        queue=False,
                        show_progress=False
                    )
                    
                    clear_button = gr.Button(value="Clear", interactive=True)
                    clear_button.click(
                        lambda: ("", [], [], [[], [], []], None, None),
                        [],
                        [prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
                        queue=False,
                        show_progress=False
                    )
                    
                    submit_button = gr.Button(
                        value="Submit", interactive=True, variant="primary"
                    )
                    submit_button.click(
                        inference_seg_cap,
                        [
                            image_input,
                            prompt_input,
                            language,
                            sentiment,
                            factuality,
                            length,
                            state,
                            click_state
                        ],
                        [chatbot, state, click_state, image_output_mask, image_output_crop],
                        show_progress=False
                    )
                    
                # select coordinate
                image_input.select(
                    get_select_coords, 
                    inputs=[image_input,point_prompt,language,sentiment,factuality,length,state,click_state], 
                    outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
                    show_progress=False
                    )

                image_input.change(
                    lambda: ("", [], [[], [], []]),
                    [],
                    [chatbot, state, click_state],
                    queue=False,
                )
                
            with gr.Column(scale=1.5):
                chat_input = gr.Textbox(lines=1, label="Chat Input")
                chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
                
                    
    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, prompt_input],
    )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)