diff --git a/README.md b/README.md index d2f62d2727715ca8e7d17d08bed4288cd9b679f0..5a90284608d2a9eb0793f793ca939bb7de2fcb37 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,49 @@ --- -title: Paligemma -emoji: 🌍 +title: PaliGemma Demo +emoji: 🤲 colorFrom: green -colorTo: gray +colorTo: yellow sdk: gradio -sdk_version: 4.31.2 +sdk_version: 4.22.0 app_file: app.py pinned: false -license: gemma +license: apache-2.0 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# PaliGemma Demo + +See [Blogpost] and [`big_vision README.md`] for details about the model. + + +[Blogpost]: https://huggingface.co/blog/paligemma + +[`big_vision README.md`]: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md + +## Development + +Local testing (CPU, Python 3.12): + +```bash +python -m venv env +. env/bin/activate +pip install -qr requirements-cpu.txt +python app.py +``` + +Environment variables: + +- `MOCK_MODEL=yes`: For quick UI testing. +- `RAM_CACHE_GB=18`: Enables caching of 3 bf16 models in memory: a single bf16 + model is about 5860 MB. Use with care on spaces with little RAM. For example, + on a `A10G large` space you can cache five models in RAM, so you would set + `RAM_CACHE_GB=30`. +- `HOST_COLOCATION=4`: If host RAM/disk is shared between 4 processes (e.g. the + Huggingface `A10 large` Spaces). + + +Loading models: + +- The set of models loaded is defined in `./models.py`. +- You must first acknowledge usage conditions to access models. +- When testing locally, you'll have to run `huggingface-cli login`. +- When running in a Huggingface Space, you'll have to set a `HF_TOKEN` secret. 
diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..63adb4ba1e3f9f4d95781e5d836b95fb6464a8c9 --- /dev/null +++ b/app.py @@ -0,0 +1,251 @@ +"""PaliGemma demo gradio app.""" + +import datetime +import functools +import glob +import json +import logging +import os +import time + +import gradio as gr +import jax +import PIL.Image +import gradio_helpers +import models +import paligemma_parse + +INTRO_TEXT = """🤲 PaliGemma demo\n\n +| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) +| [HF blog post](https://huggingface.co/blog/paligemma) +| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024) +| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) +| [Demo](https://huggingface.co/spaces/google/paligemma) +|\n\n +[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google, +inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and +built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) +vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile +model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question +answering, text reading, object detection and object segmentation. +\n\n +This space includes models fine-tuned on a mix of downstream tasks. +See the [blog post](https://huggingface.co/blog/paligemma) and +[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) +for detailed information how to use and fine-tune PaliGemma models. +\n\n +**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications. 
+""" + + +make_image = lambda value, visible: gr.Image( + value, label='Image', type='filepath', visible=visible) +make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image') +make_highlighted_text = functools.partial(gr.HighlightedText, label='Output') + + +# https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1 +COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1'] + + +@gradio_helpers.synced +def compute(image, prompt, model_name, sampler): + """Runs model inference.""" + if image is None: + raise gr.Error('Image required') + + logging.info('prompt="%s"', prompt) + + if isinstance(image, str): + image = PIL.Image.open(image) + if gradio_helpers.should_mock(): + logging.warning('Mocking response') + time.sleep(2.) + output = paligemma_parse.EXAMPLE_STRING + else: + if not model_name: + raise gr.Error('Models not loaded yet') + output = models.generate(model_name, sampler, image, prompt) + logging.info('output="%s"', output) + + width, height = image.size + objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True) + labels = set(obj.get('name') for obj in objs if obj.get('name')) + color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)} + highlighted_text = [(obj['content'], obj.get('name')) for obj in objs] + annotated_image = ( + image, + [ + ( + obj['mask'] if obj.get('mask') is not None else obj['xyxy'], + obj['name'] or '', + ) + for obj in objs + if 'mask' in obj or 'xyxy' in obj + ], + ) + has_annotations = bool(annotated_image[1]) + return ( + make_highlighted_text( + highlighted_text, visible=True, color_map=color_map), + make_image(image, visible=not has_annotations), + make_annotated_image( + annotated_image, visible=has_annotations, width=width, height=height, + color_map=color_map), + ) + + +def warmup(model_name): + image = PIL.Image.new('RGB', [1, 1]) + _ = compute(image, '', model_name, 'greedy') + + +def reset(): + return ( + '', make_highlighted_text('', visible=False), + 
make_image(None, visible=True), make_annotated_image(None, visible=False), + ) + + +def create_app(): + """Creates demo UI.""" + + make_model = lambda choices: gr.Dropdown( + value=(choices + [''])[0], + choices=choices, + label='Model', + visible=bool(choices), + ) + make_prompt = lambda value, visible=True: gr.Textbox( + value, label='Prompt', visible=visible) + + with gr.Blocks() as demo: + + ##### Main UI structure. + + gr.Markdown(INTRO_TEXT) + with gr.Row(): + image = make_image(None, visible=True) # input + annotated_image = make_annotated_image(None, visible=False) # output + with gr.Column(): + with gr.Row(): + prompt = make_prompt('', visible=True) + model_info = gr.Markdown(label='Model Info') + with gr.Row(): + model = make_model([]) + samplers = [ + 'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)'] + sampler = gr.Dropdown( + value=samplers[0], choices=samplers, label='Decoding' + ) + with gr.Row(): + run = gr.Button('Run', variant='primary') + clear = gr.Button('Clear') + highlighted_text = make_highlighted_text('', visible=False) + + ##### UI logic. + + def update_ui(model, prompt): + prompt = make_prompt(prompt, visible=True) + model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}' + return [prompt, model_info] + + gr.on( + [model.change], + update_ui, + [model, prompt], + [prompt, model_info], + ) + + gr.on( + [run.click, prompt.submit], + compute, + [image, prompt, model, sampler], + [highlighted_text, image, annotated_image], + ) + clear.click( + reset, None, [prompt, highlighted_text, image, annotated_image] + ) + + ##### Examples. 
+ + gr.set_static_paths(['examples/']) + all_examples = [json.load(open(p)) for p in glob.glob('examples/*.json')] + logging.info('loaded %d examples', len(all_examples)) + example_image = gr.Image( + label='Image', visible=False) # proxy, never visible + example_model = gr.Text( + label='Model', visible=False) # proxy, never visible + example_prompt = gr.Text( + label='Prompt', visible=False) # proxy, never visible + example_license = gr.Markdown( + label='Image License', visible=False) # placeholder, never visible + gr.Examples( + examples=[ + [ + f'examples/{ex["name"]}.jpg', + ex['prompt'], + ex['model'], + ex['license'], + ] + for ex in all_examples + if ex['model'] in models.MODELS + ], + inputs=[example_image, example_prompt, example_model, example_license], + ) + + ##### Examples UI logic. + + example_image.change( + lambda image_path: ( + make_image(image_path, visible=True), + make_annotated_image(None, visible=False), + make_highlighted_text('', visible=False), + ), + example_image, + [image, annotated_image, highlighted_text], + ) + def example_model_changed(model): + if model not in gradio_helpers.get_paths(): + raise gr.Error(f'Model "{model}" not loaded!') + return model + example_model.change(example_model_changed, example_model, model) + example_prompt.change(make_prompt, example_prompt, prompt) + + ##### Status. 
+ + status = gr.Markdown(f'Startup: {datetime.datetime.now()}') + gpu_kind = gr.Markdown(f'GPU=?') + demo.load( + lambda: [ + gradio_helpers.get_status(), + make_model(list(gradio_helpers.get_paths())), + ], + None, + [status, model], + ) + def get_gpu_kind(): + device = jax.devices()[0] + if not gradio_helpers.should_mock() and device.platform != 'gpu': + raise gr.Error('GPU not visible to JAX!') + return f'GPU={device.device_kind}' + demo.load(get_gpu_kind, None, gpu_kind) + + return demo + + +if __name__ == '__main__': + + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') + + logging.info('JAX devices: %s', jax.devices()) + + for k, v in os.environ.items(): + logging.info('environ["%s"] = %r', k, v) + + gradio_helpers.set_warmup_function(warmup) + for name, (repo, filename, revision) in models.MODELS.items(): + gradio_helpers.register_download(name, repo, filename, revision) + + create_app().queue().launch() diff --git a/examples/barsik.jpg b/examples/barsik.jpg new file mode 100644 index 0000000000000000000000000000000000000000..55f855f13e882e57272a4eed142c919e907b84b6 Binary files /dev/null and b/examples/barsik.jpg differ diff --git a/examples/barsik.json b/examples/barsik.json new file mode 100644 index 0000000000000000000000000000000000000000..6d6f13e76e15985824ab27135a8b62d8b278d0dc --- /dev/null +++ b/examples/barsik.json @@ -0,0 +1,7 @@ +{ + "name": "barsik", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "segment cat", + "license": "CC0 by [maximneumann@](https://github.com/maximneumann)" +} \ No newline at end of file diff --git a/examples/biennale.jpg b/examples/biennale.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05ba1292a74c2842df4b4341ecaf1a1c1ecbcce0 Binary files /dev/null and b/examples/biennale.jpg differ diff --git a/examples/biennale.json b/examples/biennale.json new file mode 100644 index 
0000000000000000000000000000000000000000..532ff527f32ad4e5fa1ebd71ebacc14d537370e5 --- /dev/null +++ b/examples/biennale.json @@ -0,0 +1,7 @@ +{ + "name": "biennale", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "In which city is this?", + "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)" +} \ No newline at end of file diff --git a/examples/billard1.jpg b/examples/billard1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2fbf3c5a9e96df8099640c4d9700fccac7063648 Binary files /dev/null and b/examples/billard1.jpg differ diff --git a/examples/billard1.json b/examples/billard1.json new file mode 100644 index 0000000000000000000000000000000000000000..2667d173894c20049779f493091cb00be8205d07 --- /dev/null +++ b/examples/billard1.json @@ -0,0 +1,7 @@ +{ + "name": "billard1", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "How many red balls are there?", + "license": "CC0 by [mbosnjak@](https://github.com/mbosnjak)" +} \ No newline at end of file diff --git a/examples/billard2.jpg b/examples/billard2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a2a65c4b4c837082190bda4c1ec0d95aae757387 Binary files /dev/null and b/examples/billard2.jpg differ diff --git a/examples/billard2.json b/examples/billard2.json new file mode 100644 index 0000000000000000000000000000000000000000..1e66dd97b575f666c962436482fc18ee8682493e --- /dev/null +++ b/examples/billard2.json @@ -0,0 +1,7 @@ +{ + "name": "billard2", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "How many balls are there?", + "license": "CC0 by [mbosnjak@](https://github.com/mbosnjak)" +} \ No newline at end of file diff --git a/examples/bowie.jpg b/examples/bowie.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a470c3fbcd2af4e81af9de46f6ba26f17db81631 Binary files /dev/null and b/examples/bowie.jpg differ diff --git a/examples/bowie.json b/examples/bowie.json new file mode 100644 index 
0000000000000000000000000000000000000000..deb4dfd631631946765c9e90fa4555822a453e03 --- /dev/null +++ b/examples/bowie.json @@ -0,0 +1,7 @@ +{ + "name": "bowie", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Who is this?", + "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)" +} \ No newline at end of file diff --git a/examples/branch.jpg b/examples/branch.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d95595728b845c0ec2b7ee508473541a48d84290 Binary files /dev/null and b/examples/branch.jpg differ diff --git a/examples/branch.json b/examples/branch.json new file mode 100644 index 0000000000000000000000000000000000000000..a86c14f5d3fe2f2d0512ce49fc2ab3b9b6012c61 --- /dev/null +++ b/examples/branch.json @@ -0,0 +1,7 @@ +{ + "name": "branch", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "What caused this?", + "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)" +} \ No newline at end of file diff --git a/examples/cc_fox.jpg b/examples/cc_fox.jpg new file mode 100644 index 0000000000000000000000000000000000000000..47c95d0a91241833574ccb19b8a355417a87bc7a Binary files /dev/null and b/examples/cc_fox.jpg differ diff --git a/examples/cc_fox.json b/examples/cc_fox.json new file mode 100644 index 0000000000000000000000000000000000000000..69ee0678e50e701e0167097f4c41ed360f449aed --- /dev/null +++ b/examples/cc_fox.json @@ -0,0 +1,7 @@ +{ + "name": "cc_fox", + "comment": "", + "model": "paligemma-3b-mix-448", + "prompt": "Which breed is this fox?", + "license": "CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)" +} diff --git a/examples/cc_landscape.jpg b/examples/cc_landscape.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0a4b610c4234ffa67305b7952584430f4d953dde Binary files /dev/null and b/examples/cc_landscape.jpg differ diff --git a/examples/cc_landscape.json b/examples/cc_landscape.json new file mode 100644 index 
0000000000000000000000000000000000000000..c1a66ec2901cd5108c71a690f23cf6ef51a9fbee --- /dev/null +++ b/examples/cc_landscape.json @@ -0,0 +1,7 @@ +{ + "name": "cc_landscape", + "comment": "", + "model": "paligemma-3b-mix-448", + "prompt": "What does the image show?", + "license": "CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)" +} diff --git a/examples/cc_puffin.jpg b/examples/cc_puffin.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ae6bb3dc676dc40b2854c58d8238ff7770f3072d Binary files /dev/null and b/examples/cc_puffin.jpg differ diff --git a/examples/cc_puffin.json b/examples/cc_puffin.json new file mode 100644 index 0000000000000000000000000000000000000000..3ca086360d6c168cd3587a189aebe7a7bae2ca41 --- /dev/null +++ b/examples/cc_puffin.json @@ -0,0 +1,7 @@ +{ + "name": "cc_puffin", + "comment": "", + "model": "paligemma-3b-mix-448", + "prompt": "detect puffin in the back ; puffin in front", + "license": "CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)" +} diff --git a/examples/couch.jpg b/examples/couch.jpg new file mode 100644 index 0000000000000000000000000000000000000000..81800961f0498e46fc06ef525311f0a5d88eb4cb Binary files /dev/null and b/examples/couch.jpg differ diff --git a/examples/couch.json b/examples/couch.json new file mode 100644 index 0000000000000000000000000000000000000000..32f4cba01ded6e629661c4f81ec9125f9af8409e --- /dev/null +++ b/examples/couch.json @@ -0,0 +1,7 @@ +{ + "name": "couch", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "How many yellow cushions are on the couch?", + "license": "CC0" +} \ No newline at end of file diff --git a/examples/couch_.json b/examples/couch_.json new file mode 100644 index 0000000000000000000000000000000000000000..22a288af099703296a1208279484354f88ed5c20 --- /dev/null +++ b/examples/couch_.json @@ -0,0 +1,7 @@ +{ + "name": "couch", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "How many painting do you see in the image?", 
+ "license": "CC0" +} \ No newline at end of file diff --git a/examples/cups.jpg b/examples/cups.jpg new file mode 100644 index 0000000000000000000000000000000000000000..29fb745612887e7a0d4137a503831fe4dd0841d1 Binary files /dev/null and b/examples/cups.jpg differ diff --git a/examples/cups.json b/examples/cups.json new file mode 100644 index 0000000000000000000000000000000000000000..078e3df2986f38350c30eaf2e1e1522a842b7664 --- /dev/null +++ b/examples/cups.json @@ -0,0 +1,7 @@ +{ + "name": "cups", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "how many cups?", + "license": "CC0 by [mbosnjak@](https://github.com/mbosnjak)" +} \ No newline at end of file diff --git a/examples/dice.jpg b/examples/dice.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76d0fbabee3a9aa31d3335a850f25b0c40952d70 Binary files /dev/null and b/examples/dice.jpg differ diff --git a/examples/dice.json b/examples/dice.json new file mode 100644 index 0000000000000000000000000000000000000000..a3fb3f9703dd6ea055569fba49b4a96b76df8235 --- /dev/null +++ b/examples/dice.json @@ -0,0 +1,7 @@ +{ + "name": "dice", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "segment dice ; dice", + "license": "CC0 by [andresusanopinto@](https://github.com/andresusanopinto)" +} \ No newline at end of file diff --git a/examples/emu.jpg b/examples/emu.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d298e271a108a89cb89473af26bd202902f8b901 Binary files /dev/null and b/examples/emu.jpg differ diff --git a/examples/emu.json b/examples/emu.json new file mode 100644 index 0000000000000000000000000000000000000000..23532eac207641e3d138ceb67f9a051d6d231539 --- /dev/null +++ b/examples/emu.json @@ -0,0 +1,7 @@ +{ + "name": "emu", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "What animal is this?", + "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)" +} \ No newline at end of file diff --git 
a/examples/fridge.jpg b/examples/fridge.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dd6af3f32f8b3bad650b2162f1b4628d8e5a26db Binary files /dev/null and b/examples/fridge.jpg differ diff --git a/examples/fridge.json b/examples/fridge.json new file mode 100644 index 0000000000000000000000000000000000000000..c6628d78020b331530c4c6a3726c38c454c4da2f --- /dev/null +++ b/examples/fridge.json @@ -0,0 +1,7 @@ +{ + "name": "fridge", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Describe the image.", + "license": "CC0 by [andresusanopinto@](https://github.com/andresusanopinto)" +} \ No newline at end of file diff --git a/examples/givt.jpg b/examples/givt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b269c132464ecbd8c7ce8f5464cb6ebf142cc8a1 Binary files /dev/null and b/examples/givt.jpg differ diff --git a/examples/givt.json b/examples/givt.json new file mode 100644 index 0000000000000000000000000000000000000000..4e244d55bd0423fd8041accf1ba3d9bb43d494af --- /dev/null +++ b/examples/givt.json @@ -0,0 +1,7 @@ +{ + "name": "givt", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "What does the image show?", + "license": "CC-BY [GIVT paper](https://arxiv.org/abs/2312.02116)" +} \ No newline at end of file diff --git a/examples/greenlake.jpg b/examples/greenlake.jpg new file mode 100644 index 0000000000000000000000000000000000000000..65401579082eebb41a70869d8785b2c84a437476 Binary files /dev/null and b/examples/greenlake.jpg differ diff --git a/examples/greenlake.json b/examples/greenlake.json new file mode 100644 index 0000000000000000000000000000000000000000..5de5282b9608ada567cd696e8e4846e0906088da --- /dev/null +++ b/examples/greenlake.json @@ -0,0 +1,7 @@ +{ + "name": "greenlake", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Describe the image.", + "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)" +} \ No newline at end of file diff --git 
a/examples/howto.jpg b/examples/howto.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f079c6751730ab16008835765e6296c8fcc2d8c Binary files /dev/null and b/examples/howto.jpg differ diff --git a/examples/howto.json b/examples/howto.json new file mode 100644 index 0000000000000000000000000000000000000000..2b44aae0878af6ff9abf1628ccedbb932179d19d --- /dev/null +++ b/examples/howto.json @@ -0,0 +1,7 @@ +{ + "name": "howto", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "What does this image show?", + "license": "CC-BY [How to train your ViT?](https://arxiv.org/abs/2106.10270)" +} \ No newline at end of file diff --git a/examples/markers.jpg b/examples/markers.jpg new file mode 100644 index 0000000000000000000000000000000000000000..756537b93cf074ebfd32a45a7438b46914d335c3 Binary files /dev/null and b/examples/markers.jpg differ diff --git a/examples/markers.json b/examples/markers.json new file mode 100644 index 0000000000000000000000000000000000000000..9093a2c9a468dd7995039cae415394f182c35e89 --- /dev/null +++ b/examples/markers.json @@ -0,0 +1,7 @@ +{ + "name": "markers", + "comment": "answer en How many cups are there?", + "model": "paligemma-3b-mix-224", + "prompt": "How many cups are there?", + "license": "CC0" +} \ No newline at end of file diff --git a/examples/mcair.jpg b/examples/mcair.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e965dde07d103114690bfc086d24ab3fdb054e65 Binary files /dev/null and b/examples/mcair.jpg differ diff --git a/examples/mcair.json b/examples/mcair.json new file mode 100644 index 0000000000000000000000000000000000000000..0f50b7f96253821cb70eb5aac40760d140252ffa --- /dev/null +++ b/examples/mcair.json @@ -0,0 +1,7 @@ +{ + "name": "mcair", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Can you board this airplane?", + "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)" +} \ No newline at end of file diff --git a/examples/mcair_.json 
b/examples/mcair_.json new file mode 100644 index 0000000000000000000000000000000000000000..7ae3353a0ed5109a9cc9f26dc179ec7a86357b8c --- /dev/null +++ b/examples/mcair_.json @@ -0,0 +1,7 @@ +{ + "name": "mcair", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Is this a restaurant?", + "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)" +} \ No newline at end of file diff --git a/examples/minergie.jpg b/examples/minergie.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8372f189fa3549042e752931b2577fc7384be0fc Binary files /dev/null and b/examples/minergie.jpg differ diff --git a/examples/minergie.json b/examples/minergie.json new file mode 100644 index 0000000000000000000000000000000000000000..cb292ed5e6eafc30ccecb9d7b2569ee370ab3b06 --- /dev/null +++ b/examples/minergie.json @@ -0,0 +1,7 @@ +{ + "name": "minergie", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "ocr", + "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)" +} \ No newline at end of file diff --git a/examples/morel.jpg b/examples/morel.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e1498f0c98d6a2a7187ec74b809ca9f8fdc776c2 Binary files /dev/null and b/examples/morel.jpg differ diff --git a/examples/morel.json b/examples/morel.json new file mode 100644 index 0000000000000000000000000000000000000000..c4fb09a89a268cae5cbdff0810feea34177da7c8 --- /dev/null +++ b/examples/morel.json @@ -0,0 +1,7 @@ +{ + "name": "morel", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "detect morel", + "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)" +} \ No newline at end of file diff --git a/examples/motorcyclists.jpg b/examples/motorcyclists.jpg new file mode 100644 index 0000000000000000000000000000000000000000..91fdffa020ea0e2ba5ef1a9be7dd68bdb7a081ce Binary files /dev/null and b/examples/motorcyclists.jpg differ diff --git a/examples/motorcyclists.json 
b/examples/motorcyclists.json new file mode 100644 index 0000000000000000000000000000000000000000..f4a0d1e8b7207ac55ed90d09cbfc68ea065c901c --- /dev/null +++ b/examples/motorcyclists.json @@ -0,0 +1,7 @@ +{ + "name": "motorcyclists", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "What does the image show?", + "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)" +} \ No newline at end of file diff --git a/examples/parking.jpg b/examples/parking.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3b3c6d3ebb5057b228f04a85a83184c5a1c8aaba Binary files /dev/null and b/examples/parking.jpg differ diff --git a/examples/parking.json b/examples/parking.json new file mode 100644 index 0000000000000000000000000000000000000000..9964ba3acfadec3e0165377bb182ec416672b49a --- /dev/null +++ b/examples/parking.json @@ -0,0 +1,7 @@ +{ + "name": "parking", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Describe the image.", + "license": "CC0 by [xiaohuazhai@](https://huggingface.co/xiaohuazhai)" +} \ No newline at end of file diff --git a/examples/password.jpg b/examples/password.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c7804dfa42aa3cad23089087bffca604a8c3507e Binary files /dev/null and b/examples/password.jpg differ diff --git a/examples/password.json b/examples/password.json new file mode 100644 index 0000000000000000000000000000000000000000..070f3f8c992a948b177f2cc647a7f9260c0d6c38 --- /dev/null +++ b/examples/password.json @@ -0,0 +1,7 @@ +{ + "name": "password", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "What is the password?", + "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)" +} \ No newline at end of file diff --git a/examples/preservationhall.jpg b/examples/preservationhall.jpg new file mode 100644 index 0000000000000000000000000000000000000000..adab242566923b4fa3ac97b179c358b07697f221 Binary files /dev/null and 
b/examples/preservationhall.jpg differ diff --git a/examples/preservationhall.json b/examples/preservationhall.json new file mode 100644 index 0000000000000000000000000000000000000000..6f9be7e1169ea9269ee0e9ebe58d32b424b4c21f --- /dev/null +++ b/examples/preservationhall.json @@ -0,0 +1,7 @@ +{ + "name": "preservationhall", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Describe the image.", + "license": "CC0 by [mitscha@](https://github.com/mitscha)" +} \ No newline at end of file diff --git a/examples/preservationhall_.json b/examples/preservationhall_.json new file mode 100644 index 0000000000000000000000000000000000000000..5571c5272f91d36f2d67a55aac2d61715e3e5f26 --- /dev/null +++ b/examples/preservationhall_.json @@ -0,0 +1,7 @@ +{ + "name": "preservationhall", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "What's the name of the place?", + "license": "CC0 by [mitscha@](https://github.com/mitscha)" +} \ No newline at end of file diff --git a/examples/ulges.jpg b/examples/ulges.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e91d86083e3dbfecc2688735d380a441c0dee227 Binary files /dev/null and b/examples/ulges.jpg differ diff --git a/examples/ulges.json b/examples/ulges.json new file mode 100644 index 0000000000000000000000000000000000000000..d22ee5806c716238cd83fcbad2597dcf31dc6e04 --- /dev/null +++ b/examples/ulges.json @@ -0,0 +1,7 @@ +{ + "name": "ulges", + "comment": "", + "model": "paligemma-3b-mix-224", + "prompt": "Who is the author of this book?", + "license": "CC0" +} \ No newline at end of file diff --git a/gradio_helpers.py b/gradio_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1de6dc7e34bf56bdb0ba4b451208638a11868879 --- /dev/null +++ b/gradio_helpers.py @@ -0,0 +1,280 @@ +"""Gradio helpers for caching, downloading etc.""" + +import concurrent.futures +import contextlib +import datetime +import functools +import logging +import os +import shutil +import 
subprocess +import sys +import tempfile +import threading +import time +import unittest.mock + +import huggingface_hub +import jax +import numpy as np +import psutil + + +def _clone_git(url, destination_folder, commit_hash=None): + subprocess.run([ + 'git', 'clone', '--depth=1', + url, destination_folder + ], check=True) + if commit_hash: + subprocess.run( + ['git', '-C', destination_folder, 'checkout', commit_hash], check=True + ) + + +def setup(): + """Installs big_vision repo and mocks tensorflow_text.""" + for url, dst_name, commit_hash in ( + ( + 'https://github.com/google-research/big_vision', + 'big_vision_repo', + None, + ), + ): + dst_path = os.path.join(tempfile.gettempdir(), dst_name) + if os.path.exists(dst_path): + print('Found existing "%s" at "%s"' % (url, dst_path)) + else: + print('Cloning "%s" into "%s"' % (url, dst_path)) + _clone_git(url, dst_path, commit_hash) + + if dst_path not in sys.path: + sys.path.insert(0, dst_path) + + # Imported in `big_vision.pp.ops_text` but we don't use it. 
+ sys.modules['tensorflow_text'] = unittest.mock.MagicMock() + + +# Must be run in main app before other BV imports: +setup() + + +def should_mock(): + """Returns `True` if `MOCK_MODEL=yes` is set in environment.""" + return os.environ.get('MOCK_MODEL') == 'yes' + + +@contextlib.contextmanager +def timed(name, start_message=False): + """Emits "Timed {name}: .1f secs" message to INFO logs.""" + t0 = time.monotonic() + timing = dict(dt=None) + try: + if start_message: + logging.info('Timing %s...', name) + yield timing + finally: + timing['secs'] = time.monotonic() - t0 + logging.info('Timed %s: %.1f secs', name, timing['secs']) + + +def synced(f): + """Syncs calls to `f` with a `threading.Lock()`.""" + lock = threading.Lock() + @functools.wraps(f) + def wrapper(*args, **kw): + t0 = time.monotonic() + with lock: + lock_dt = time.monotonic() - t0 + logging.info('synced wait: %.1f secs', lock_dt) + return f(*args, **kw) + return wrapper + + +_warmed_up = set() +_warmup_function = None + + +def set_warmup_function(warmup_function): + global _warmup_function + _warmup_function = warmup_function + + +_lock = threading.Lock() +_scheduled = {} +_download_secs = 0 +_warmup_secs = 0 +_loading_secs = 0 +_done = {} +_failed = {} + + +def _do_download(): + """Downloading files, to be started in background thread.""" + global _download_secs + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + while True: + if not _scheduled: + time.sleep(1) + continue + + name, (repo, filename, revision) = next(iter(_scheduled.items())) + logging.info('Downloading "%s" %s/%s/%s...', name, repo, filename, revision) + with timed(f'downloading {name}', True) as t: + if should_mock(): + logging.warning('Mocking loading') + time.sleep(10.) 
+ _done[name] = None + else: + try: + _done[name] = huggingface_hub.hf_hub_download( + repo_id=repo, filename=filename, revision=revision) + except Exception as e: # pylint: disable=broad-exception-caught + logging.exception('Could not download "%s" from hub!', name) + _failed[name] = str(e) + with _lock: + _scheduled.pop(name) + continue + + if _warmup_function: + def warmup(name): + global _warmup_secs + with timed(f'warming up {name}', True) as t: + try: + _warmup_function(name) + _warmed_up.add(name) + except Exception: # pylint: disable=broad-exception-caught + logging.exception('Could not warmup "%s"!', name) + _warmup_secs += t['secs'] + executor.submit(warmup, name) + + _download_secs += t['secs'] + with _lock: + _scheduled.pop(name) + + +def register_download(name, repo, filename, revision='main'): + """Will cause download of `filename` from HF `repo` in background thread.""" + with _lock: + if name not in _scheduled: + _scheduled[name] = (repo, filename, revision) + + +def _hms(secs): + """Formats `secs=3700` to `"01:01:40"`.""" + secs = int(secs) + h = secs // 3600 + m = (secs - h * 3600) // 60 + s = secs % 60 + return (f'{h}:' if h else '') + f'{m:02}:{s:02}' + + +def downloads_status(): + """Returns string representation of download stats.""" + done_t = remaining_t = '' + if _done: + done_t = f' in {_hms(_download_secs)}' + remaining_t = f' {_hms(_download_secs/len(_done)*len(_scheduled))}' + status = f'Downloaded {len(_done)}{done_t}' + if _scheduled: + status += f', {len(_scheduled)}{remaining_t} remaining' + if _warmup_function: + status += f', warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}' + if _failed: + status += f', {len(_failed)} failed' + return status + + +def get_paths(): + """Returns dictionary `name` to `path` from previous `register_download()`.""" + return dict(_done) + + +_download_thread = threading.Thread(target=_do_download) +_download_thread.daemon = True +_download_thread.start() + + +_estimated_real = [(10, 10)] 
+_memory_cache = {} + + +def get_with_progress(getter, secs, progress, step=0.1): + """Returns result from `getter` while showing a progress bar.""" + if progress is None: + return getter() + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(getter) + for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'): + if not future.done(): + time.sleep(step) + return future.result() + + +def _get_array_sizes(tree): + return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)] + + +def get_memory_cache( + key, getter, max_cache_size_bytes, progress=None, estimated_secs=None +): + """Keeps cache below specified size by removing elements not last accessed.""" + if key in _memory_cache: + _memory_cache[key] = _memory_cache.pop(key) # Updates "last accessed" order + return _memory_cache[key] + + est, real = zip(*_estimated_real) + if estimated_secs is None: + estimated_secs = sum(est) / len(est) + with timed(f'loading {key}') as t: + estimated_secs *= sum(real) / sum(est) + value = get_with_progress(getter, estimated_secs, progress) + _estimated_real.append((estimated_secs, t['secs'])) + + if not max_cache_size_bytes: + return value + + _memory_cache[key] = value + sz = sum(_get_array_sizes(list(_memory_cache.values()))) + logging.info('New memory cache size=%.1f MB', sz/1e6) + + while sz > max_cache_size_bytes: + k, v = next(iter(_memory_cache.items())) + if k == key: + break + s = sum(_get_array_sizes(v)) + logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6) + _memory_cache.pop(k) + sz -= s + + return value + + +def get_memory_cache_info(): + """Returns number of items and total size in bytes.""" + sizes = _get_array_sizes(_memory_cache) + return len(_memory_cache), sum(sizes) + + +def get_system_info(): + """Returns string describing system's RAM/disk status.""" + host_colocation = int(os.environ.get('HOST_COLOCATION', '1')) + vm = psutil.virtual_memory() + du = shutil.disk_usage('.') + return ( + f'RAM 
{vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}G, ' + f'disk {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}G' + ) + + +def get_status(include_system_info=True): + """Returns string about download/memory/system status.""" + mc_len, mc_sz = get_memory_cache_info() + mc_t = _hms(sum(real for _, real in _estimated_real[1:])) + return ( + 'Timestamp: ' + + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + ' – Model stats: ' + + downloads_status() + + ', ' + f'memory-cached {mc_len} ({mc_sz/1e9:.1f}G) in {mc_t}' + + (' – System: ' + get_system_info() if include_system_info else '') + ) diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c2b588c911c94878398ef6636431afa803443785 --- /dev/null +++ b/models.py @@ -0,0 +1,87 @@ +"""Model-related code and constants.""" + +import dataclasses +import os +import re + +import PIL.Image + +# pylint: disable=g-bad-import-order +import gradio_helpers +import paligemma_bv + + +ORGANIZATION = 'google' +BASE_MODELS = [ + ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'), + ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'), +] +MODELS = { + **{ + model_name: ( + f'{ORGANIZATION}/{repo}', + f'{model_name}.bf16.npz', + 'bfloat16', # Model repo revision. + ) + for repo, model_name in BASE_MODELS + }, +} + +MODELS_INFO = { + 'paligemma-3b-mix-224': ( + 'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output ' + 'text sequences on a mixture of downstream academic datasets. The models are available in float32, ' + 'bfloat16 and float16 format for research purposes only.' + ), + 'paligemma-3b-mix-448': ( + 'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output ' + 'text sequences on a mixture of downstream academic datasets. The models are available in float32, ' + 'bfloat16 and float16 format for research purposes only.' 
+ ), +} + +MODELS_RES_SEQ = { + 'paligemma-3b-mix-224': (224, 256), + 'paligemma-3b-mix-448': (448, 512), +} + +# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM. +# Below value should be smaller than "available RAM - one model". +# A single bf16 is about 5860 MB. +MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9) + +config = paligemma_bv.PaligemmaConfig( + ckpt='', # will be set below + res=224, + text_len=64, + tokenizer='gemma(tokensets=("loc", "seg"))', + vocab_size=256_000 + 1024 + 128, +) + + +def get_cached_model( + model_name: str, +) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]: + """Returns model and params, using RAM cache.""" + res, seq = MODELS_RES_SEQ[model_name] + model_path = gradio_helpers.get_paths()[model_name] + config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq) + model, params_cpu = gradio_helpers.get_memory_cache( + config_, + lambda: paligemma_bv.load_model(config_), + max_cache_size_bytes=MAX_RAM_CACHE, + ) + return model, params_cpu + + +def generate( + model_name: str, sampler: str, image: PIL.Image.Image, prompt: str +) -> str: + """Generates output with specified `model_name`, `sampler`.""" + model, params_cpu = get_cached_model(model_name) + batch = model.shard_batch(model.prepare_batch([image], [prompt])) + with gradio_helpers.timed('sharding'): + params = model.shard_params(params_cpu) + with gradio_helpers.timed('computation', start_message=True): + tokens = model.predict(params, batch, sampler=sampler) + return model.tokenizer.to_str(tokens[0]) diff --git a/paligemma_bv.py b/paligemma_bv.py new file mode 100644 index 0000000000000000000000000000000000000000..b70512134711c80a387f6a816cec7a64842a8dd9 --- /dev/null +++ b/paligemma_bv.py @@ -0,0 +1,207 @@ +"""Wraps `big_vision` PaliGemma model for easy use in demo.""" + +from collections.abc import Callable +import dataclasses +from typing import Any + +import jax +import jax.numpy as jnp +import ml_collections +import 
numpy as np +import PIL.Image + +from big_vision import sharding +from big_vision import utils +from big_vision.models.proj.paligemma import paligemma +from big_vision.pp import builder as pp_builder +from big_vision.pp import ops_general # pylint: disable=unused-import +from big_vision.pp import ops_image # pylint: disable=unused-import +from big_vision.pp import ops_text # pylint: disable=unused-import +from big_vision.pp import tokenizer +from big_vision.pp.proj.paligemma import ops as ops_paligemma # pylint: disable=unused-import +from big_vision.trainers.proj.paligemma import predict_fns + + +mesh = jax.sharding.Mesh(jax.devices(), 'data') + + +def _recover_bf16(x): + if x.dtype == np.dtype('V2'): + x = x.view('bfloat16') + return x + + +def _load( + path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152 +): + """Loads model, params, decode functions and tokenizer.""" + tok = tokenizer.get_tokenizer(tokenizer_spec) + + config = ml_collections.FrozenConfigDict(dict( + llm_model='proj.paligemma.gemma_bv', + llm=dict(vocab_size=vocab_size, variant='gemma_2b'), + img=dict(variant='So400m/14', pool_type='none', scan=True), + )) + model = paligemma.Model(**config) + decode = predict_fns.get_all(model)['decode'] + beam_decode = predict_fns.get_all(model)['beam_decode'] + + params_cpu = paligemma.load(None, path, config) + # Some numpy versions don't load bfloat16 correctly: + params_cpu = jax.tree.map(_recover_bf16, params_cpu) + + return model, params_cpu, decode, beam_decode, tok + + +def _shard_params(params_cpu): + """Shards `params_cpu` with fsdp strategy on all available devices.""" + params_sharding = sharding.infer_sharding( + params_cpu, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh + ) + params = jax.tree.map(utils.reshard, params_cpu, params_sharding) + return params + + +def _pil2np(img): + """Accepts `PIL.Image` or `np.ndarray` and returns `np.ndarray`.""" + if isinstance(img, PIL.Image.Image): + img = np.array(img) + img = 
img[..., :3] + if img.ndim == 2: + img = img[..., None] + if img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + return img + + +def _prepare_batch( + images, + prefixes, + *, + res=224, + tokenizer_spec='gemma(tokensets=("loc", "seg"))', + suffixes=None, + text_len=64, +): + """Returns non-sharded batch.""" + + pp_fn = pp_builder.get_preprocess_fn('|'.join([ + f'resize({res}, antialias=True)|value_range(-1, 1)', + f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')", + f"tok(key='septok', text='\\n', model='{tokenizer_spec}')", + f"tok(key='suffix', model='{tokenizer_spec}')", + 'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])', # pylint: disable=line-too-long + f'tolen({text_len}, pad_value=0, key="text")', + f'tolen({text_len}, pad_value=1, key="mask_ar")', + f'tolen({text_len}, pad_value=0, key="mask_input")', + 'keep("image", "text", "mask_ar", "mask_input")', + ]), log_data=False) + assert not isinstance(prefixes, str), f'expected batch: {prefixes}' + assert ( + isinstance(images, (list, tuple)) or images.ndim == 4 + ), f'expected batch: {images.shape}' + if suffixes is None: + suffixes = [''] * len(prefixes) + assert len(prefixes) == len(suffixes) == len(images) + examples = [{'_mask': True, **pp_fn({ + 'image': np.asarray(_pil2np(image)), + 'prefix': np.array(prefix), + 'suffix': np.array(suffix), + })} for image, prefix, suffix in zip(images, prefixes, suffixes)] + batch = jax.tree_map(lambda *xs: np.stack(xs), *examples) + return batch + + +def _shard_batch(batch, n=None): + """Shards `batch` with fsdp strategy on all available devices.""" + if n is None: + n = jax.local_device_count() + def pad(x): + return jnp.pad(x, [(0, -len(x) % n)] + [(0, 0)] * (x.ndim - 1)) + batch = {k: pad(v) for k, v in batch.items()} + data_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('data') + ) + batch_on_device = utils.reshard(batch, data_sharding) + return batch_on_device + + 
+@dataclasses.dataclass(frozen=True, kw_only=True, order=True) +class PaligemmaConfig: + """Desribes a `big_vision` PaliGemma model.""" + + ckpt: str + res: int + text_len: int + tokenizer: str + vocab_size: int + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class PaliGemmaModel: + """Wraps a `big_vision` PaliGemma model.""" + + config: PaligemmaConfig + tokenizer: tokenizer.Tokenizer + decode: Callable[..., Any] + beam_decode: Callable[..., Any] + + @classmethod + def shard_batch(cls, batch): + return _shard_batch(batch) + + @classmethod + def shard_params(cls, params_cpu): + return _shard_params(params_cpu) + + def prepare_batch(self, images, texts, suffixes=None): + return _prepare_batch( + images=images, + prefixes=texts, + suffixes=suffixes, + res=self.config.res, + tokenizer_spec=self.config.tokenizer, + text_len=self.config.text_len, + ) + + def predict( + self, + params, + batch, + devices=None, + max_decode_len=128, + sampler='greedy', + **kw, + ): + """Returns tokens.""" + if devices is None: + devices = jax.devices() + if sampler == 'beam': + decode = self.beam_decode + else: + decode = self.decode + kw['sampler'] = sampler + return decode( + {'params': params}, + batch=batch, + devices=devices, + eos_token=self.tokenizer.eos_token, + max_decode_len=max_decode_len, + **kw, + ) + + +ParamsCpu = Any + + +def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]: + """Loads model from config.""" + model, params_cpu, decode, beam_decode, tok = _load( + path=config.ckpt, + tokenizer_spec=config.tokenizer, + vocab_size=config.vocab_size, + ) + del model + return PaliGemmaModel( + config=config, tokenizer=tok, decode=decode, beam_decode=beam_decode, + ), params_cpu diff --git a/paligemma_parse.py b/paligemma_parse.py new file mode 100644 index 0000000000000000000000000000000000000000..38356f2103edfe88852402de53948ce220c1b665 --- /dev/null +++ b/paligemma_parse.py @@ -0,0 +1,184 @@ +"""Parses PaliGemma output.""" + +import functools 
+import re + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import PIL.Image + + +EXAMPLE_STRING = ' wall ; car ; david bowie ; david bowie ; david bowie' # pylint: disable=line-too-long + +_MODEL_PATH = 'vae-oid.npz' + +_SEGMENT_DETECT_RE = re.compile( + r'(.*?)' + + r'' * 4 + r'\s*' + + '(?:%s)?' % (r'' * 16) + + r'\s*([^;<>]+)? ?(?:; )?', +) + + +def _get_params(checkpoint): + """Converts PyTorch checkpoint to Flax params.""" + + def transp(kernel): + return np.transpose(kernel, (2, 3, 1, 0)) + + def conv(name): + return { + 'bias': checkpoint[name + '.bias'], + 'kernel': transp(checkpoint[name + '.weight']), + } + + def resblock(name): + return { + 'Conv_0': conv(name + '.0'), + 'Conv_1': conv(name + '.2'), + 'Conv_2': conv(name + '.4'), + } + + return { + '_embeddings': checkpoint['_vq_vae._embedding'], + 'Conv_0': conv('decoder.0'), + 'ResBlock_0': resblock('decoder.2.net'), + 'ResBlock_1': resblock('decoder.3.net'), + 'ConvTranspose_0': conv('decoder.4'), + 'ConvTranspose_1': conv('decoder.6'), + 'ConvTranspose_2': conv('decoder.8'), + 'ConvTranspose_3': conv('decoder.10'), + 'Conv_1': conv('decoder.12'), + } + + +def _quantized_values_from_codebook_indices(codebook_indices, embeddings): + batch_size, num_tokens = codebook_indices.shape + assert num_tokens == 16, codebook_indices.shape + unused_num_embeddings, embedding_dim = embeddings.shape + + encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0) + encodings = encodings.reshape((batch_size, 4, 4, embedding_dim)) + return encodings + + +@functools.cache +def _get_reconstruct_masks(): + """Reconstructs masks from codebook indices. + + Returns: + A function that expects indices shaped `[B, 16]` of dtype int32, each + ranging from 0 to 127 (inclusive), and that returns a decoded masks sized + `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1]. 
+ """ + + class ResBlock(nn.Module): + features: int + + @nn.compact + def __call__(self, x): + original_x = x + x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x) + x = nn.relu(x) + x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x) + x = nn.relu(x) + x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x) + return x + original_x + + class Decoder(nn.Module): + """Upscales quantized vectors to mask.""" + + @nn.compact + def __call__(self, x): + num_res_blocks = 2 + dim = 128 + num_upsample_layers = 4 + + x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x) + x = nn.relu(x) + + for _ in range(num_res_blocks): + x = ResBlock(features=dim)(x) + + for _ in range(num_upsample_layers): + x = nn.ConvTranspose( + features=dim, + kernel_size=(4, 4), + strides=(2, 2), + padding=2, + transpose_kernel=True, + )(x) + x = nn.relu(x) + dim //= 2 + + x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x) + + return x + + def reconstruct_masks(codebook_indices): + quantized = _quantized_values_from_codebook_indices( + codebook_indices, params['_embeddings'] + ) + return Decoder().apply({'params': params}, quantized) + + with open(_MODEL_PATH, 'rb') as f: + params = _get_params(dict(np.load(f))) + + return jax.jit(reconstruct_masks, backend='cpu') + + +def extract_objs(text, width, height, unique_labels=False): + """Returns objs for a string with "" and "" tokens.""" + objs = [] + seen = set() + while text: + m = _SEGMENT_DETECT_RE.match(text) + if not m: + break + + gs = list(m.groups()) + before = gs.pop(0) + name = gs.pop() + y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]] + y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width)) + + seg_indices = gs[4:20] + if seg_indices[0] is None: + mask = None + else: + seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32) + m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0] + m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1) + m64 = 
PIL.Image.fromarray((m64 * 255).astype('uint8')) + mask = np.zeros([height, width]) + if y2 > y1 and x2 > x1: + mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0 + + content = m.group() + if before: + objs.append(dict(content=before)) + content = content[len(before):] + while unique_labels and name in seen: + name = (name or '') + "'" + seen.add(name) + objs.append(dict( + content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name)) + text = text[len(before) + len(content):] + + if text: + objs.append(dict(content=text)) + + return objs + + +if __name__ == '__main__': + # Simple test. + print([ + { + k: (v.shape, v.mean()) if isinstance(v, np.ndarray) else v + for k, v in obj.items() + } + for obj in extract_objs(EXAMPLE_STRING, 100, 200) + ]) diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..e54c76d0fa705382d9cfb17745a48488830ed081 --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,13 @@ +einops +flax +gradio +huggingface-hub +jax +jaxlib +ml_collections +numpy +orbax-checkpoint +Pillow +psutil +sentencepiece +tensorflow diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8fa778537ea48a6629baab79680ab7ec388a0ab3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +einops +flax +gradio +huggingface-hub +-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +jax[cuda12_pip]~=0.4.25 +jaxlib +ml_collections +numpy +orbax-checkpoint +Pillow +psutil +sentencepiece +tensorflow-cpu diff --git a/vae-oid.npz b/vae-oid.npz new file mode 100644 index 0000000000000000000000000000000000000000..e30bd245fc0b67df063c5bd49d83c7130bba2637 --- /dev/null +++ b/vae-oid.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5586010257b8536dddefab65e7755077f21d5672d5674dacf911f73ae95a4447 +size 8479556