import os
import torch
import numpy as np
import triton_python_backend_utils as pb_utils
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler


class TritonPythonModel:
    """Triton Python-backend model that answers prompts about an image.

    Every Python model served by Triton must expose a class named exactly
    "TritonPythonModel". This one loads a llama.cpp LLaVA 1.5 chat model
    and, for each request, takes one text prompt plus one base64-encoded
    image and returns the generated answer as a UTF-8 string.
    """

    # Quantized language-model weights and the CLIP projector weights that
    # the LLaVA chat handler needs to embed images.
    MODEL_PATH = "/weights/ggml-model-q4_0.gguf"
    CLIP_MODEL_PATH = "/weights/mmproj-model-f16.gguf"
    # Value passed to llama.cpp as `n_ctx`.
    N_CTX = 2048
    # System message placed in front of every conversation.
    SYSTEM_PROMPT = "You are an assistant who perfectly describes images."

    @staticmethod
    def auto_complete_config(auto_complete_model_config):
        """Declare the model's inputs/outputs and enable dynamic batching.

        `auto_complete_config` is called only once when loading the model,
        assuming the server was not started with
        `--disable-auto-complete-config`. Tensors that are already present in
        the existing configuration (matched by name) are left untouched; only
        the missing ones are added.

        Note: The Python interpreter used to invoke this function is
        destroyed upon returning from it, so none of the objects created here
        are available in `initialize`, `execute`, or `finalize`.

        Parameters
        ----------
        auto_complete_model_config : pb_utils.ModelConfig
          An object containing the existing model configuration.

        Returns
        -------
        pb_utils.ModelConfig
          The same object, with `PROMPT`, `IMAGES` and `RESULTS` declared
          (if they were missing) and dynamic batching enabled.
        """
        declared_inputs = [
            {
                'name': 'PROMPT',
                'data_type': 'TYPE_STRING',
                'dims': [-1],
            },
            {
                # Images are sent as base64 text (see `run_inference`), so
                # the tensor is a variable-length array of strings.
                'name': 'IMAGES',
                'data_type': 'TYPE_STRING',
                'dims': [-1],
            },
        ]
        declared_outputs = [
            {
                'name': 'RESULTS',
                'data_type': 'TYPE_STRING',
                'dims': [-1],
            },
        ]

        config = auto_complete_model_config.as_dict()
        known_inputs = {tensor['name'] for tensor in config['input']}
        known_outputs = {tensor['name'] for tensor in config['output']}

        for tensor in declared_inputs:
            if tensor['name'] not in known_inputs:
                auto_complete_model_config.add_input(tensor)
        for tensor in declared_outputs:
            if tensor['name'] not in known_outputs:
                auto_complete_model_config.add_output(tensor)

        auto_complete_model_config.set_dynamic_batching()

        return auto_complete_model_config

    def initialize(self, args):
        """Load the LLaVA model once, when Triton loads this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device
            ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name

          None of these are currently read; the weight locations come from
          the class constants `MODEL_PATH` and `CLIP_MODEL_PATH`.
        """
        chat_handler = Llava15ChatHandler(
            clip_model_path=self.CLIP_MODEL_PATH,
            verbose=True,
        )
        # NOTE(review): per the llama-cpp-python docs, n_gpu_layers=-1
        # offloads every layer to the GPU and logits_all=True keeps logits
        # for all tokens -- confirm both are still wanted for this deployment.
        self.model = Llama(
            model_path=self.MODEL_PATH,
            chat_handler=chat_handler,
            n_ctx=self.N_CTX,
            logits_all=True,
            n_gpu_layers=-1,
        )
        print('Initialized...')

    def run_inference(self, prompt, image):
        """Ask the model `prompt` about `image` and return its answer.

        Parameters
        ----------
        prompt : str
          The user's question or instruction.
        image : str
          Either the bare base64 payload of an image (it is then wrapped as
          a `data:image/png;base64,...` URI) or a complete `data:` URI,
          which is forwarded unchanged.

        Returns
        -------
        numpy.ndarray
          A one-element object array holding the UTF-8 encoded answer, ready
          to be wrapped in a `pb_utils.Tensor`.
        """
        # A ':' can never appear in base64, so a leading "data:" reliably
        # identifies a caller that already sent a full data URI; wrapping it
        # again would produce a malformed URL.
        if image.startswith("data:"):
            image_url = image
        else:
            image_url = f"data:image/png;base64,{image}"

        messages = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        completion = self.model.create_chat_completion(messages=messages)
        answer = completion["choices"][0]["message"]["content"]
        return np.array([answer.encode('utf-8')], dtype=object)

    @staticmethod
    def _first_text(values):
        """Return the first element of a string tensor as a Python `str`."""
        # Flatten first so the same code works whether or not the array
        # carries a leading batch dimension.
        flat = np.asarray(values).reshape(-1)
        if flat.size == 0:
            raise ValueError("input tensor is empty")
        first = flat[0]
        if isinstance(first, bytes):
            return first.decode("UTF-8")
        return str(first)

    @staticmethod
    def _read_text_input(request, name):
        """Fetch input tensor `name` from `request` and decode its first item."""
        tensor = pb_utils.get_input_tensor_by_name(request, name)
        if tensor is None:
            raise ValueError(f"missing required input tensor '{name}'")
        return TritonPythonModel._first_text(tensor.as_numpy())

    def execute(self, requests):
        """Run inference for every request and return one response each.

        `execute` must be implemented in every Python model. A failure while
        handling one request is reported through that request's response
        (as a `pb_utils.TritonError`) instead of aborting the whole call, so
        the other requests in the same call still get their results.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []

        for request in requests:
            try:
                # NOTE(review): only the first PROMPT / IMAGES element of a
                # request is used -- confirm that requests never carry more
                # than one prompt/image pair.
                prompt = self._read_text_input(request, "PROMPT")
                image = self._read_text_input(request, "IMAGES")
                results = self.run_inference(prompt, image)
            except Exception as exc:
                # Request boundary: turn any failure into an error response
                # so one bad request cannot fail its neighbours.
                responses.append(pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(str(exc)),
                ))
                continue

            responses.append(pb_utils.InferenceResponse(output_tensors=[
                pb_utils.Tensor("RESULTS", results),
            ]))

        return responses

    def finalize(self):
        """Release the model when Triton unloads this model.

        `finalize` is called only once when the model is being unloaded.
        """
        print('Cleaning up...')
        # Drop the reference so the loaded model can be garbage-collected.
        self.model = None