import gradio as gr
import requests
from PIL import Image
from pdf2image import convert_from_path  # requires the poppler utilities to be installed
from typing import Dict, List, Optional, Union
from io import BytesIO
import base64
import numpy as np
import json

prompt = """You are an advanced document parsing bot. Given the fixture schedule I provided, you need to parse out 

1. the name of the fixture
2. the company that produces this fixture
3. the description of this fixture. This is a 20-word description that summarizes the size, function, and mounting method of the fixture and mentions any necessary accessories. For example: 1" x 1" recessed downlight.
4. the part number of this fixture. It is a series of specification codes joined with "-", and you can get the info by reading the text marked in a different color or by reading the top bar. Include every specification code, in the correct order, in your answer.
5. the input wattage of this fixture, short answer. Please answer the wattage according to the part number you found in item 4.

Please format your response as JSON:
{
    "fixture_name": <fixture name>,
    "manufacture_name": <company name>,
    "fixture_description": <description>,
    "mfr": <part number>,
    "input wattage": <numerical input wattage>
}

---
For example
{
    "fixture_name": "SW24/1.5 Led Strips - Static White",
    "manufacture_name": "Q-Tran Inc.",
    "fixture_description": "Surface mounted static white LED strip.",
    "mfr": "SW24-1.5-DRY-30-BW-BW-WH-CL2-535",
    "input wattage": "1.5W"
}"""

def query_openai_api(messages, model, temperature=0, max_tokens=None, api_key=None, organization_key=None, json_mode=False):
    try:
        url = "https://api.openai.com/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        # Only send the organization header when a non-empty key was provided
        # (the Gradio textbox yields "" when the field is left blank).
        if organization_key:
            headers["OpenAI-Organization"] = organization_key

        data = {"model": model, "messages": messages, "temperature": temperature}
        if max_tokens is not None:
            data["max_tokens"] = max_tokens
        if json_mode:
            data["response_format"] = {"type": "json_object"}

        response = requests.post(url, headers=headers, json=data, timeout=120).json()
        return response["choices"][0]["message"]["content"].lstrip(), response
    except Exception as e:
        print(f"An error occurred: {e}")
        return f"API_ERROR: {e}", None
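
# A minimal retry sketch, not part of the original tool: max_retries and the
# backoff schedule here are illustrative assumptions for riding out transient
# rate-limit or network failures.
def query_openai_api_with_retry(messages, model, max_retries=3, **kwargs):
    import time  # local import; only this sketch needs it

    result = query_openai_api(messages, model, **kwargs)
    for attempt in range(max_retries):
        if result[1] is not None:  # got a parsed response back; stop retrying
            return result
        time.sleep(2 ** attempt)  # simple exponential backoff: 1s, 2s, 4s, ...
        result = query_openai_api(messages, model, **kwargs)
    return result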

class GPT4V_Client:
    def __init__(self, api_key, organization_key, model_name="gpt-4o", max_tokens=512):
        self.api_key = api_key
        self.organization_key = organization_key
        self.model_name = model_name
        self.max_tokens = max_tokens

    def chat(self, messages, json_mode):
        return query_openai_api(
            messages,
            self.model_name,
            max_tokens=self.max_tokens,  # previously stored but never forwarded to the API
            api_key=self.api_key,
            organization_key=self.organization_key,
            json_mode=json_mode,
        )

    @staticmethod
    def _image_to_data_url(image: Union[Image.Image, np.ndarray]) -> str:
        """Encode a PIL image or numpy array as a base64 JPEG data URL."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        buffer = BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")
        jpg_base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/jpeg;base64,{jpg_base64_str}"

    def one_step_chat(
        self,
        text,
        image: Union[Image.Image, np.ndarray],
        system_msg: Optional[str] = None,
        json_mode=False,
    ):
        messages = []
        if system_msg is not None:
            messages.append({"role": "system", "content": system_msg})
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": text},
                    {
                        "type": "image_url",
                        "image_url": {"url": self._image_to_data_url(image)},
                    },
                ],
            }
        )
        return self.chat(messages, json_mode=json_mode)
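
    # Usage sketch (hypothetical key and image path):
    #   client = GPT4V_Client(api_key="sk-...", organization_key=None)
    #   page = Image.open("cutsheet_page1.jpg")
    #   text, raw = client.one_step_chat("Summarize this fixture.", page)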

    def one_step_multi_image_chat(
        self,
        text,
        images: List[Dict],  # each entry: {"image": PIL.Image or np.ndarray, "detail": "low" | "high" | "auto"}
        system_msg: Optional[str] = None,
        json_mode=False,
    ):
        img_sub_msg = [
            {
                "type": "image_url",
                "image_url": {
                    "url": self._image_to_data_url(img_info["image"]),
                    "detail": img_info["detail"],
                },
            }
            for img_info in images
        ]

        messages = []
        if system_msg is not None:
            messages.append({"role": "system", "content": system_msg})
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}] + img_sub_msg,
            }
        )
        return self.chat(messages, json_mode=json_mode)
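
    # Usage sketch (hypothetical pages): each entry pairs an image with the
    # API's per-image "detail" level.
    #   pages = [{"image": page1, "detail": "high"}, {"image": page2, "detail": "low"}]
    #   text, raw = client.one_step_multi_image_chat(prompt, pages, json_mode=True)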

def markdown_json_to_table(markdown_json_string, iteration, thumbnail_md):
    """
    Convert the GPT JSON string into a markdown row with the first column as the PDF thumbnail.
    Args:
        markdown_json_string: the raw markdown (JSON) string from GPT
        iteration: which row # we are on
        thumbnail_md: something like ![pdfpage](data:image/jpeg;base64,xxxxxx)
    Returns:
        A string with either:
          - The header row + first data row, if iteration==0
          - Additional data row, if iteration>0
    """
    # The model may wrap the JSON in a fenced code block; peel the fence off.
    # Note: str.strip("json") removes matching *characters*, not the substring,
    # so use removeprefix/removesuffix (Python 3.9+) instead.
    json_string = markdown_json_string.strip()
    if json_string.startswith("```"):
        json_string = json_string.removeprefix("```json").removeprefix("```")
        json_string = json_string.removesuffix("```").strip()

    # Safely parse JSON
    try:
        json_obj = json.loads(json_string)
    except Exception:
        # If it can't parse, return empty
        return ""

    # Turn the JSON object into a list of values for easier table building
    # e.g. [fixture_name, manufacture_name, mfr, input wattage]
    keys = list(json_obj.keys())
    values = list(json_obj.values())

    # We want the first column to be the PDF thumbnail
    # So the table columns become: [Thumbnail, key1, key2, key3, ...]
    # This means we have one extra column in front compared to the JSON.

    # If iteration == 0, produce header
    # e.g. | Thumbnail | fixture_name | manufacture_name | mfr | input wattage |
    if iteration == 0:
        header = ["Thumbnail"] + keys
        header_row = "| " + " | ".join(header) + " |\n"
        sep_row = "|" + "|".join(["---"] * len(header)) + "|\n"
    else:
        header_row = ""
        sep_row = ""

    # Then produce the data row
    # e.g. | ![pdfpage](data:image/jpeg;base64,xxx) | "SW24..." | "Q-Tran Inc." | ...
    # Escape pipe characters so a value can't break the Markdown table layout
    str_values = [str(v).replace("|", "\\|") for v in values]
    data_row = "| " + thumbnail_md + " | " + " | ".join(str_values) + " |\n"

    return header_row + sep_row + data_row
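
# Worked example: when iteration == 0 the header and separator rows precede
# the data row; later iterations emit only the data row.
#   markdown_json_to_table('{"fixture_name": "A1", "mfr": "X-1-30"}', 0, "![pdfpage](...)")
# returns
#   | Thumbnail | fixture_name | mfr |
#   |---|---|---|
#   | ![pdfpage](...) | A1 | X-1-30 |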


def gptRead(cutsheets, api_key, organization_key):
    fixtureInfo = ""
    client = GPT4V_Client(api_key=api_key, organization_key=organization_key)

    for iteration, cutsheet in enumerate(cutsheets):
        # Gradio may hand back plain path strings or tempfile-like objects
        # depending on version, so accept both.
        path = cutsheet.name if hasattr(cutsheet, "name") else cutsheet

        # Convert the first page of the PDF into an image
        source = convert_from_path(path)[0]

        # Create a smaller thumbnail
        thumbnail_img = source.copy()
        thumbnail_img.thumbnail((100, 100))

        # Encode the thumbnail to base64 for embedding in Markdown
        thumb_io = BytesIO()
        thumbnail_img.save(thumb_io, format="JPEG")
        base64_thumb = base64.b64encode(thumb_io.getvalue()).decode("utf-8")
        thumbnail_md = f"![pdfpage](data:image/jpeg;base64,{base64_thumb})"

        # Chat with GPT about the original (non-thumbnail) image; json_mode
        # keeps the reply as bare JSON so the table parser gets clean input
        response_text, _ = client.one_step_chat(prompt, source, json_mode=True)

        # Convert the GPT JSON to a Markdown row, including the thumbnail in the first column
        fixtureInfo += markdown_json_to_table(response_text, iteration, thumbnail_md)

    return fixtureInfo
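
# Sketch (assumption: some cutsheets spread specs across several pages): to let
# the model see every page instead of only the first, the pages could be sent
# together through one_step_multi_image_chat, e.g.
#   pages = convert_from_path(path)
#   payload = [{"image": p, "detail": "high"} for p in pages]
#   response_text, _ = client.one_step_multi_image_chat(prompt, payload, json_mode=True)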

if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("# Lighting Manufacturer Cutsheet GPT Tool")
        api_key = gr.Textbox(label="OpenAI API key", type="password")
        organization_key = gr.Textbox(label="OpenAI organization key", info="(optional)")
        file_uploader = gr.UploadButton("Upload cutsheets", type="filepath", file_count="multiple", file_types=[".pdf"])
        form = gr.Markdown()

        # When the user uploads, call gptRead -> produce the final Markdown table
        file_uploader.upload(fn=gptRead, inputs=[file_uploader, api_key, organization_key], outputs=form)

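    # Note: share=True publishes a public link. gr.Blocks.launch also accepts
    # simple credentials, e.g. demo.launch(share=True, auth=("user", "pass")),
    # if the tool should not be open to anyone holding the URL.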
    demo.launch(share=True)