"""PaliGemma demo gradio app."""

import datetime
import functools
import glob
import json
import logging
import os
import time

import gradio as gr
import PIL.Image
import gradio_helpers
import models
import paligemma_parse

INTRO_TEXT = """🤲 PaliGemma GGUF demo\n\n
| [Paper](https://arxiv.org/abs/2407.07726)
| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) 
| [HF blog post](https://huggingface.co/blog/paligemma) 
| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024)
| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) 
| [Demo](https://huggingface.co/spaces/google/paligemma) 
|\n\n
[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google, 
inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and 
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) 
vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile 
model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual 
question answering, text reading, object detection, and object segmentation.
\n\n
This space includes models fine-tuned on a mix of downstream tasks. 
See the [blog post](https://huggingface.co/blog/paligemma) and 
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) 
for detailed information on how to use and fine-tune PaliGemma models.
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""


make_image = lambda value, visible: gr.Image(
    value, label='Image', type='filepath', visible=visible)
make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image')
make_highlighted_text = functools.partial(gr.HighlightedText, label='Output')
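# Event handlers below return fresh instances from these factories: in Gradio,
# returning a new component from a handler updates the rendered component's
# properties (here mainly `value` and `visible`) in place.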


# https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']


@gradio_helpers.synced
def compute(image, prompt, model_name, sampler):
  """Runs model inference."""
  if image is None:
    raise gr.Error('Image required')

  logging.info('prompt="%s"', prompt)

  if isinstance(image, str):
    image = PIL.Image.open(image)
  if gradio_helpers.should_mock():
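    # Return canned parser output so the UI flow can be exercised end-to-end
    # without any model weights being loaded.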
    logging.warning('Mocking response')
    time.sleep(2.)
    output = paligemma_parse.EXAMPLE_STRING
  else:
    if not model_name:
      raise gr.Error('Models not loaded yet')
    output = models.generate(model_name, sampler, image, prompt)
    logging.info('output="%s"', output)

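  # extract_objs() is expected to split the raw output into segments, turning
  # PaliGemma's `<loc....>` tokens into pixel-space `xyxy` boxes and its
  # `<seg...>` tokens into binary `mask` arrays where present.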
  width, height = image.size
  objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)
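  # Sort labels so that colors are assigned deterministically across runs
  # (sets have no stable iteration order).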
  labels = sorted({obj['name'] for obj in objs if obj.get('name')})
  color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
  highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
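  # gr.AnnotatedImage takes (base_image, [(annotation, label), ...]), where an
  # annotation is either a mask image/array or an (x1, y1, x2, y2) pixel box.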
  annotated_image = (
      image,
      [
          (
              obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
              obj['name'] or '',
          )
          for obj in objs
          if 'mask' in obj or 'xyxy' in obj
      ],
  )
  has_annotations = bool(annotated_image[1])
  return (
      make_highlighted_text(
          highlighted_text, visible=True, color_map=color_map),
      make_image(image, visible=not has_annotations),
      make_annotated_image(
          annotated_image, visible=has_annotations, width=width, height=height,
          color_map=color_map),
  )


def warmup(model_name):
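  """Runs one dummy generation so the first real request is served warm."""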
  image = PIL.Image.new('RGB', (1, 1))
  _ = compute(image, '', model_name, 'greedy')


def reset():
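  """Returns initial component states for the `clear` button."""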
  return (
      '', make_highlighted_text('', visible=False),
      make_image(None, visible=True), make_annotated_image(None, visible=False),
  )


def create_app():
  """Creates demo UI."""

  make_model = lambda choices: gr.Dropdown(
      value=(choices + [''])[0],
      choices=choices,
      label='Model',
      visible=bool(choices),
  )
  make_prompt = lambda value, visible=True: gr.Textbox(
      value, label='Prompt', visible=visible)

  with gr.Blocks() as demo:

    ##### Main UI structure.

    gr.Markdown(INTRO_TEXT)
    with gr.Row():
      image = make_image(None, visible=True)  # input
      annotated_image = make_annotated_image(None, visible=False)  # output
      with gr.Column():
        with gr.Row():
          prompt = make_prompt('', visible=True)
        model_info = gr.Markdown(label='Model Info')
        with gr.Row():
          model = make_model([])
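          # Decoding options are passed through as strings to
          # models.generate(), which is assumed to parse them (greedy
          # decoding, nucleus/top-p, temperature sampling).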
          samplers = [
              'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)']
          sampler = gr.Dropdown(
              value=samplers[0], choices=samplers, label='Decoding'
          )
        with gr.Row():
          run = gr.Button('Run', variant='primary')
          clear = gr.Button('Clear')
        highlighted_text = make_highlighted_text('', visible=False)

    ##### UI logic.

    def update_ui(model, prompt):
      prompt = make_prompt(prompt, visible=True)
      model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}'
      return [prompt, model_info]

    gr.on(
        [model.change],
        update_ui,
        [model, prompt],
        [prompt, model_info],
    )

    gr.on(
        [run.click, prompt.submit],
        compute,
        [image, prompt, model, sampler],
        [highlighted_text, image, annotated_image],
    )
    clear.click(
        reset, None, [prompt, highlighted_text, image, annotated_image]
    )

    ##### Examples.

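    # Serve files under examples/ as static assets so they are not copied into
    # Gradio's cache.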
    gr.set_static_paths(['examples/'])
    all_examples = [
        json.load(open(p)) for p in sorted(glob.glob('examples/*.json'))]
    logging.info('loaded %d examples', len(all_examples))
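    # gr.Examples() can only write into the components listed as `inputs`, so
    # hidden proxy components receive a clicked example's values, and the
    # .change handlers below forward them into the visible UI. Each JSON file
    # is expected to provide `name`, `prompt`, `model`, and `license` keys.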
    example_image = gr.Image(
        label='Image', visible=False)  # proxy, never visible
    example_model = gr.Text(
        label='Model', visible=False)  # proxy, never visible
    example_prompt = gr.Text(
        label='Prompt', visible=False)  # proxy, never visible
    example_license = gr.Markdown(
        label='Image License', visible=False)  # placeholder, never visible
    gr.Examples(
        examples=[
            [
                f'examples/{ex["name"]}.jpg',
                ex['prompt'],
                ex['model'],
                ex['license'],
            ]
            for ex in all_examples
            if ex['model'] in models.MODELS
        ],
        inputs=[example_image, example_prompt, example_model, example_license],
    )

    ##### Examples UI logic.

    example_image.change(
        lambda image_path: (
            make_image(image_path, visible=True),
            make_annotated_image(None, visible=False),
            make_highlighted_text('', visible=False),
        ),
        example_image,
        [image, annotated_image, highlighted_text],
    )
    def example_model_changed(model):
      if model not in gradio_helpers.get_paths():
        raise gr.Error(f'Model "{model}" not loaded!')
      return model
    example_model.change(example_model_changed, example_model, model)
    example_prompt.change(make_prompt, example_prompt, prompt)

    ##### Status.

    status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
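    # Refresh the status line and (re)populate the model dropdown on each page
    # load; models become selectable once their downloads have completed.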
    demo.load(
        lambda: [
            gradio_helpers.get_status(),
            make_model(list(gradio_helpers.get_paths())),
        ],
        None,
        [status, model],
    )

  return demo


if __name__ == '__main__':

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

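  # models.MODELS maps a model name to a (repo, filenames) tuple. Downloads
  # are registered with gradio_helpers, which presumably fetches them in the
  # background and invokes the warmup hook once a model is ready.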
  gradio_helpers.set_warmup_function(warmup)
  for name, (repo, filenames) in models.MODELS.items():
    gradio_helpers.register_download(name, repo, filenames)

  create_app().queue().launch()