paligemma_vqav2

This model is a fine-tuned version of google/paligemma-3b-pt-224, trained on a small subset of the VQAv2 dataset. The fine-tuning code is here.
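The card does not say which part of VQAv2 was used or how questions and answers were formatted for training. The sketch below is a rough, hypothetical illustration. It draws a small, reproducible subset from VQAv2-style records and turns each record into a (prompt, target) pair. It assumes the `question` and `multiple_choice_answer` field names from the VQAv2 annotations and an `"answer "` task prefix. The prefix, the helper names, and the toy records are all assumptions, not the author's training code.

```python
import random


def sample_chunk(records: list, n: int, seed: int = 42) -> list:
    """Draw a reproducible random subset of at most n records."""
    rng = random.Random(seed)
    return rng.sample(records, min(n, len(records)))


def to_training_pair(example: dict, prefix: str = "answer ") -> tuple:
    """Map one VQAv2-style record to a (prompt, target) pair.

    The "answer " prefix is a hypothetical task prefix, not confirmed by the card.
    """
    prompt = prefix + example["question"].strip()
    target = example["multiple_choice_answer"].strip()
    return prompt, target


# Toy records standing in for VQAv2 annotations (made up for illustration).
records = [
    {"question": "What color is the bus?", "multiple_choice_answer": "red"},
    {"question": "How many dogs are there? ", "multiple_choice_answer": "2"},
    {"question": "Is it raining?", "multiple_choice_answer": "no"},
]

chunk = sample_chunk(records, n=2)
pairs = [to_training_pair(ex) for ex in chunk]
for prompt, target in pairs:
    print(f"{prompt!r} -> {target!r}")
```

The subsample reuses the card's training seed (42) only for convenience. The card does not say how its chunk was selected.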

How to Use

Below is the code for running inference with this model. See also the inference notebook.

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests

model_id = "merve/paligemma_vqav2"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
# The processor (tokenizer + image preprocessing) comes from the base checkpoint
processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")

prompt = "What is behind the cat?"
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat.png?download=true"
raw_image = Image.open(requests.get(image_file, stream=True).raw).convert("RGB")

# Pass text and image by keyword so the call does not depend on the processor's positional-argument order
inputs = processor(text=prompt, images=raw_image, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

# Decode only the newly generated tokens, skipping the (image + text) prompt tokens
prompt_len = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][prompt_len:], skip_special_tokens=True))
# gramophone

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 2e-05
  • train_batch_size: 4
  • eval_batch_size: 8
  • seed: 42
  • gradient_accumulation_steps: 4
  • total_train_batch_size: 16
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 2
  • num_epochs: 2
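The linear scheduler with warmup raises the learning rate from 0 to `learning_rate` over `lr_scheduler_warmup_steps` optimizer steps. It then decays the rate linearly to 0 by the final step. The effective batch size is presumably `train_batch_size × gradient_accumulation_steps × number of devices`, so the listed total of 16 suggests a single device. The pure-Python sketch below reproduces that schedule shape with the card's values. The total step count depends on the dataset size, which the card does not give, so `NUM_EXAMPLES = 1000` is a hypothetical placeholder. The step count also uses simple floor division rather than the Trainer's exact bookkeeping.

```python
def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Values taken from the card
BASE_LR = 2e-5
WARMUP_STEPS = 2
PER_DEVICE_BATCH = 4
GRAD_ACCUM = 4
NUM_EPOCHS = 2
NUM_DEVICES = 1  # assumption: consistent with total_train_batch_size = 16

# Examples consumed per optimizer step
total_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_DEVICES

# Hypothetical dataset size; the card does not state it
NUM_EXAMPLES = 1000
steps_per_epoch = NUM_EXAMPLES // total_batch
TOTAL_STEPS = steps_per_epoch * NUM_EPOCHS

for step in (0, 1, WARMUP_STEPS, TOTAL_STEPS // 2, TOTAL_STEPS):
    print(step, linear_warmup_lr(step, BASE_LR, WARMUP_STEPS, TOTAL_STEPS))
```

With only 2 warmup steps, the rate reaches its peak almost immediately. It then spends nearly all of training decaying toward 0.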

Framework versions

  • Transformers 4.42.0.dev0
  • PyTorch 2.3.0+cu121
  • Datasets 2.19.1
  • Tokenizers 0.19.1
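To compare a local environment against the versions listed above, the standard library's `importlib.metadata` is enough. This is a convenience sketch, not part of the original card, and `CARD_VERSIONS` and `installed_versions` are illustrative names. Because `4.42.0.dev0` is a development build (presumably installed from source), an exact match on `transformers` is not expected.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed on this card, keyed by their PyPI distribution names
CARD_VERSIONS = {
    "transformers": "4.42.0.dev0",
    "torch": "2.3.0+cu121",
    "datasets": "2.19.1",
    "tokenizers": "0.19.1",
}


def installed_versions(names):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


installed = installed_versions(CARD_VERSIONS)
for name, expected in CARD_VERSIONS.items():
    got = installed[name]
    status = "match" if got == expected else "differs"
    print(f"{name}: card={expected} installed={got} ({status})")
```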
Model size: 2.92B params (Safetensors)
Tensor type: BF16
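As a back-of-the-envelope check, the weights alone take roughly parameter count × bytes per parameter. For 2.92B parameters, that comes to about 5.8 GB in BF16 (2 bytes each) and about twice that in float32 (4 bytes each). The sketch below does this arithmetic. It ignores activations, the KV cache during generation, and framework overhead, so real memory use will be higher. In the Transformers version listed above, `from_pretrained` loads weights in float32 unless `torch_dtype` is passed, so the float32 figure is the relevant one for the usage snippet as written.

```python
def weight_memory_bytes(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory for the weights only (no activations or cache)."""
    return num_params * bytes_per_param


NUM_PARAMS = 2.92e9  # parameter count from the card

for dtype, nbytes in (("bfloat16", 2), ("float32", 4)):
    total = weight_memory_bytes(NUM_PARAMS, nbytes)
    print(f"{dtype}: {total / 1e9:.2f} GB ({total / 2**30:.2f} GiB)")
```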