"""PaliGemma demo gradio app.""" import datetime import functools import glob import json import logging import os import time import gradio as gr import jax import PIL.Image import gradio_helpers import models import paligemma_parse INTRO_TEXT = """🤲 PaliGemma demo\n\n | [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) | [HF blog post](https://huggingface.co/blog/paligemma) | [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024) | [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) | [Demo](https://huggingface.co/spaces/google/paligemma) |\n\n [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question answering, text reading, object detection and object segmentation. \n\n This space includes models fine-tuned on a mix of downstream tasks. See the [blog post](https://huggingface.co/blog/paligemma) and [README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) for detailed information how to use and fine-tune PaliGemma models. \n\n **This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications. """ make_image = lambda value, visible: gr.Image( value, label='Image', type='filepath', visible=visible) make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image') make_highlighted_text = functools.partial(gr.HighlightedText, label='Output') # https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1 COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1'] @gradio_helpers.synced def compute(image, prompt, model_name, sampler): """Runs model inference.""" if image is None: raise gr.Error('Image required') logging.info('prompt="%s"', prompt) if isinstance(image, str): image = PIL.Image.open(image) if gradio_helpers.should_mock(): logging.warning('Mocking response') time.sleep(2.) 
    output = paligemma_parse.EXAMPLE_STRING
  else:
    if not model_name:
      raise gr.Error('Models not loaded yet')
    output = models.generate(model_name, sampler, image, prompt)
    logging.info('output="%s"', output)

  width, height = image.size
  objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)
  labels = set(obj.get('name') for obj in objs if obj.get('name'))
  color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
  highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
  annotated_image = (
      image,
      [
          (
              obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
              obj['name'] or '',
          )
          for obj in objs
          if 'mask' in obj or 'xyxy' in obj
      ],
  )
  has_annotations = bool(annotated_image[1])
  return (
      make_highlighted_text(
          highlighted_text, visible=True, color_map=color_map),
      make_image(image, visible=not has_annotations),
      make_annotated_image(
          annotated_image, visible=has_annotations,
          width=width, height=height, color_map=color_map),
  )


def warmup(model_name):
  image = PIL.Image.new('RGB', [1, 1])
  _ = compute(image, '', model_name, 'greedy')


def reset():
  return (
      '',
      make_highlighted_text('', visible=False),
      make_image(None, visible=True),
      make_annotated_image(None, visible=False),
  )


def create_app():
  """Creates demo UI."""

  make_model = lambda choices: gr.Dropdown(
      value=(choices + [''])[0],
      choices=choices,
      label='Model',
      visible=bool(choices),
  )
  make_prompt = lambda value, visible=True: gr.Textbox(
      value, label='Prompt', visible=visible)

  with gr.Blocks() as demo:

    ##### Main UI structure.

    gr.Markdown(INTRO_TEXT)
    with gr.Row():
      image = make_image(None, visible=True)  # input
      annotated_image = make_annotated_image(None, visible=False)  # output
      with gr.Column():
        with gr.Row():
          prompt = make_prompt('', visible=True)
          model_info = gr.Markdown(label='Model Info')
        with gr.Row():
          model = make_model([])
          samplers = [
              'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)']
          sampler = gr.Dropdown(
              value=samplers[0], choices=samplers, label='Decoding')
        with gr.Row():
          run = gr.Button('Run', variant='primary')
          clear = gr.Button('Clear')
        highlighted_text = make_highlighted_text('', visible=False)

    ##### UI logic.

    def update_ui(model, prompt):
      prompt = make_prompt(prompt, visible=True)
      model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}'
      return [prompt, model_info]

    gr.on(
        [model.change],
        update_ui,
        [model, prompt],
        [prompt, model_info],
    )
    gr.on(
        [run.click, prompt.submit],
        compute,
        [image, prompt, model, sampler],
        [highlighted_text, image, annotated_image],
    )
    clear.click(
        reset, None, [prompt, highlighted_text, image, annotated_image])

    ##### Examples.

    gr.set_static_paths(['examples/'])
    all_examples = [json.load(open(p)) for p in glob.glob('examples/*.json')]
    logging.info('loaded %d examples', len(all_examples))
    example_image = gr.Image(
        label='Image', visible=False)  # proxy, never visible
    example_model = gr.Text(
        label='Model', visible=False)  # proxy, never visible
    example_prompt = gr.Text(
        label='Prompt', visible=False)  # proxy, never visible
    example_license = gr.Markdown(
        label='Image License', visible=False)  # placeholder, never visible
    gr.Examples(
        examples=[
            [
                f'examples/{ex["name"]}.jpg',
                ex['prompt'],
                ex['model'],
                ex['license'],
            ]
            for ex in all_examples
            if ex['model'] in models.MODELS
        ],
        inputs=[example_image, example_prompt, example_model, example_license],
    )

    ##### Examples UI logic.
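    # gr.Examples() can only write into the components listed as `inputs`, so
    # the hidden proxy components above receive the selected example's values,
    # and the .change() handlers below forward them into the visible UI.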
    example_image.change(
        lambda image_path: (
            make_image(image_path, visible=True),
            make_annotated_image(None, visible=False),
            make_highlighted_text('', visible=False),
        ),
        example_image,
        [image, annotated_image, highlighted_text],
    )

    def example_model_changed(model):
      if model not in gradio_helpers.get_paths():
        raise gr.Error(f'Model "{model}" not loaded!')
      return model
    example_model.change(example_model_changed, example_model, model)

    example_prompt.change(make_prompt, example_prompt, prompt)

    ##### Status.

    status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
    gpu_kind = gr.Markdown('GPU=?')
    demo.load(
        lambda: [
            gradio_helpers.get_status(),
            make_model(list(gradio_helpers.get_paths())),
        ],
        None,
        [status, model],
    )

    def get_gpu_kind():
      device = jax.devices()[0]
      if not gradio_helpers.should_mock() and device.platform != 'gpu':
        raise gr.Error('GPU not visible to JAX!')
      return f'GPU={device.device_kind}'
    demo.load(get_gpu_kind, None, gpu_kind)

  return demo


if __name__ == '__main__':

  logging.basicConfig(
      level=logging.INFO,
      format='%(asctime)s - %(levelname)s - %(message)s')

  logging.info('JAX devices: %s', jax.devices())

  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  gradio_helpers.set_warmup_function(warmup)
  for name, (repo, filename, revision) in models.MODELS.items():
    gradio_helpers.register_download(name, repo, filename, revision)

  create_app().queue().launch()
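# A minimal smoke-test sketch for the inference path, kept as a comment so it
# never runs in production. It is a hypothetical usage, not part of the app:
# it assumes this file is saved as app.py and that gradio_helpers is configured
# so should_mock() returns True (in which case the empty model name is fine and
# the canned response is parsed instead of running a model):
#
#   import PIL.Image
#   from app import compute
#   outputs = compute(
#       PIL.Image.new('RGB', (224, 224)), 'caption en', '', 'greedy')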