|
--- |
|
license: apache-2.0 |
|
tags: |
|
- image-captioning |
|
|
datasets: |
|
- michelecafagna26/hl |
|
language: |
|
- en |
|
metrics: |
|
- sacrebleu |
|
- rouge |
|
library_name: transformers |
|
--- |
|
## ClipCap fine-tuned for Action Image Captioning |
|
|
|
[ClipCap](https://arxiv.org/abs/2111.09734) base model fine-tuned on the [HL Dataset](https://huggingface.co/datasets/michelecafagna26/hl) for **high-level action description generation**.
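
To get a feel for the training data, the snippet below is a minimal sketch that loads a sample from the HL Dataset with the `datasets` library; column names are not assumed here, so it simply prints the schema (check the dataset card for details).

```python
# Hedged sketch: inspect a sample from the HL Dataset used for fine-tuning.
# Column names are not assumed; check them via the printed schema or the dataset card.
from datasets import load_dataset

ds = load_dataset("michelecafagna26/hl", split="train")
print(ds.column_names)  # actual schema of the dataset
print(ds[0])            # one image/caption sample
```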
|
|
|
## Model fine-tuning 🏋️
|
|
|
We fine-tune the language model and the mapping network, starting from the ClipCap model pretrained on COCO; a hedged sketch of this setup is shown after the list below.
|
|
|
- Trained for 10 epochs |
|
- lr: 5e-5
|
- Adam optimizer |
|
- half-precision (fp16) |
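
Below is a minimal sketch of a training loop matching the hyperparameters above. It is not the released fine-tuning script: the `ClipCaptionModel` forward signature, the dataloader batch layout, and the loss masking are assumptions based on the ClipCap codebase.

```python
# Hedged sketch of the fine-tuning setup described above, not the released script.
# `model` is assumed to be a ClipCaptionModel and `dataloader` to yield
# (tokens, mask, prefix) batches of tokenized captions and CLIP image prefixes.
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast


def finetune(model, dataloader, prefix_length=10, epochs=10, lr=5e-5, device="cuda"):
    model.train().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Adam, lr = 5e-5
    scaler = GradScaler()                                     # half-precision (fp16)

    for _ in range(epochs):                                   # 10 epochs
        for tokens, mask, prefix in dataloader:
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device)
            optimizer.zero_grad()
            with autocast():
                outputs = model(tokens, prefix, mask)         # assumed forward signature
                # predict caption tokens from the positions following the prefix
                logits = outputs.logits[:, prefix_length - 1 : -1]
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.shape[-1]),
                    tokens.flatten(),
                    ignore_index=0,
                )
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
```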
|
|
|
## Test set metrics 🧾
|
|
|
| CIDEr | SacreBLEU | ROUGE-L |
|
|---------|------------|--------| |
|
| 176.54 | 27.37 | 39.15 | |
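
For reference, SacreBLEU and ROUGE-L can be computed with the `evaluate` library as sketched below; this is an illustrative snippet with placeholder captions, not the evaluation script used to produce the numbers above.

```python
# Hedged sketch: scoring generated captions with the `evaluate` library.
# The predictions and references below are placeholders, not model outputs.
import evaluate

predictions = ["she is posing for a photo."]
references = [["she is posing for a photo.", "she is taking a selfie."]]

sacrebleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")

print(sacrebleu.compute(predictions=predictions, references=references)["score"])
# rouge is scored against the first reference only in this sketch
print(rouge.compute(predictions=predictions,
                    references=[r[0] for r in references])["rougeL"])
```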
|
|
|
## Demo |
|
|
|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Rw9_oNNfP2QsIpekmJhRHAXv_6MX-0ur?usp=sharing) |
|
|
|
## Installation |
|
|
|
```bash |
|
pip install git+https://github.com/michelecafagna26/CLIPCap.git |
|
``` |
|
|
|
## Download the model |
|
|
|
```bash |
|
git lfs install # if not installed |
|
git clone https://huggingface.co/michelecafagna26/clipcap-base-captioning-ft-hl-actions |
|
``` |
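
Alternatively, just the checkpoint file can be fetched with `huggingface_hub`; the filename below is assumed to match the `model_path` used in the snippet that follows.

```python
# Hedged alternative to git clone: download only the checkpoint file.
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="michelecafagna26/clipcap-base-captioning-ft-hl-actions",
    filename="pytorch_model.pt",  # assumed checkpoint name, as used below
)
```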
|
|
|
## Model in Action 🚀
|
|
|
|
|
```python |
|
from clipcap import ClipCaptionModel |
|
from transformers import GPT2Tokenizer
|
import torch |
|
import clip |
|
import requests |
|
from PIL import Image |
|
|
|
model_path = "clipcap-base-captioning-ft-hl-actions/pytorch_model.pt" # change accordingly |
|
|
|
# load clip |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False) |
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") |
|
prefix_length = 10 |
|
|
|
# load ClipCap |
|
model = ClipCaptionModel(prefix_length, tokenizer=tokenizer) |
|
model.from_pretrained(model_path) |
|
model = model.eval() |
|
model = model.to(device) |
|
|
|
# load the image |
|
img_url = 'https://datasets-server.huggingface.co/assets/michelecafagna26/hl/--/default/train/0/image/image.jpg' |
|
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB') |
|
|
|
|
|
# extract the prefix |
|
image = preprocess(raw_image).unsqueeze(0).to(device) |
|
with torch.no_grad(): |
|
prefix = clip_model.encode_image(image).to( |
|
device, dtype=torch.float32 |
|
) |
|
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1) |
|
|
|
# generate the caption |
|
caption = model.generate_beam(embed=prefix_embed)[0]

print(caption)
|
|
|
|
|
# >> "she is posing for a photo." |
|
``` |
|
|
|
## BibTeX and citation info
|
|
|
```BibTeX |
|
``` |