"""Model-related code and constants.""" import dataclasses import os import re import PIL.Image # pylint: disable=g-bad-import-order import gradio_helpers import paligemma_bv #ORGANIZATION = 'google' ORGANIZATION = 'cocktailpeanut' BASE_MODELS = [ ('pg', 'paligemma-3b-mix-224'), ('pg', 'paligemma-3b-mix-448'), # ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'), # ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'), ] MODELS = { **{ model_name: ( f'{ORGANIZATION}/{repo}', f'{model_name}.bf16.npz', 'main', #'bfloat16', # Model repo revision. ) for repo, model_name in BASE_MODELS }, } MODELS_INFO = { 'paligemma-3b-mix-224': ( 'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output ' 'text sequences on a mixture of downstream academic datasets. The models are available in float32, ' 'bfloat16 and float16 format for research purposes only.' ), 'paligemma-3b-mix-448': ( 'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output ' 'text sequences on a mixture of downstream academic datasets. The models are available in float32, ' 'bfloat16 and float16 format for research purposes only.' ), } MODELS_RES_SEQ = { 'paligemma-3b-mix-224': (224, 256), 'paligemma-3b-mix-448': (448, 512), } # "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM. # Below value should be smaller than "available RAM - one model". # A single bf16 is about 5860 MB. 
# RAM budget (bytes) for the in-memory model cache; configured in GB via the
# RAM_CACHE_GB environment variable (defaults to 0, i.e. no caching).
MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)

# Template config shared by all models; `ckpt`, `res` and `text_len` are
# overridden per model in `get_cached_model()`.
config = paligemma_bv.PaligemmaConfig(
    ckpt='',  # will be set below
    res=224,
    text_len=64,
    tokenizer='gemma(tokensets=("loc", "seg"))',
    vocab_size=256_000 + 1024 + 128,
)


def get_cached_model(
    model_name: str,
) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
  """Loads `model_name` (or fetches it from the RAM cache).

  Args:
    model_name: Key into `MODELS_RES_SEQ` / `gradio_helpers.get_paths()`.

  Returns:
    A `(model, params_cpu)` tuple with params kept on the host.
  """
  resolution, seq_len = MODELS_RES_SEQ[model_name]
  ckpt_path = gradio_helpers.get_paths()[model_name]
  model_config = dataclasses.replace(
      config, ckpt=ckpt_path, res=resolution, text_len=seq_len)
  # The lambda is only invoked on a cache miss.
  return gradio_helpers.get_memory_cache(
      model_config,
      lambda: paligemma_bv.load_model(model_config),
      max_cache_size_bytes=MAX_RAM_CACHE,
  )


def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
  """Runs a single prediction with `model_name` using `sampler`.

  Args:
    model_name: Key identifying which cached model to use.
    sampler: Sampling strategy passed through to `model.predict`.
    image: Input image.
    prompt: Input text prompt.

  Returns:
    The decoded output text for the single (image, prompt) example.
  """
  model, cpu_params = get_cached_model(model_name)
  example_batch = model.prepare_batch([image], [prompt])
  sharded_batch = model.shard_batch(example_batch)
  with gradio_helpers.timed('sharding'):
    device_params = model.shard_params(cpu_params)
  with gradio_helpers.timed('computation', start_message=True):
    output_tokens = model.predict(device_params, sharded_batch, sampler=sampler)
  # Single-example batch: decode only the first row of tokens.
  return model.tokenizer.to_str(output_tokens[0])