---
license: mit
language:
- en
metrics:
- code_eval
library_name: transformers
pipeline_tag: image-to-text
tags:
- text-generation-inference
---
<u><b>We are creating a spatially aware vision-language (VL) model.</b></u>

This model is trained on COCO dataset images, augmented with extra information about the spatial relationships between the entities in each image.

This is a sequence-to-sequence model for image captioning. The architecture is a <u><b>ViT encoder with a GPT-2 decoder.</b></u>
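
Since the architecture is the standard ViT encoder / GPT-2 decoder pair, the checkpoint should also be loadable directly through `VisionEncoderDecoderModel` instead of the pipeline shown further below. The following is a minimal sketch; that the image processor and tokenizer are bundled with this checkpoint is an assumption:

```python
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

# Load the encoder-decoder model plus its image processor and tokenizer
# (assumes both are shipped with the checkpoint).
model = VisionEncoderDecoderModel.from_pretrained("sadassa17/rgb-language_cap")
processor = ViTImageProcessor.from_pretrained("sadassa17/rgb-language_cap")
tokenizer = AutoTokenizer.from_pretrained("sadassa17/rgb-language_cap")

image = Image.open("path/to/file").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values

# Generate a caption with the same 200-token budget used in the pipeline example.
output_ids = model.generate(pixel_values, max_new_tokens=200)
caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(caption)
```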

<details>
  <summary>Requirements</summary>

- 4 GB of GPU RAM (a quick check is sketched below)
- A CUDA-enabled Docker environment

</details>
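
A minimal sketch for checking the GPU requirement from Python (the 4 GB threshold is taken from the list above, not from profiling):

```python
import torch

# Report whether a CUDA device is available and how much memory it has.
if torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU memory: {total_gb:.1f} GB "
          f"({'meets' if total_gb >= 4 else 'below'} the 4 GB requirement)")
else:
    print("No CUDA device found; the pipeline below will fall back to CPU.")
```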

To download and run the model:
```python
import torch
from transformers import pipeline

# Run on GPU if one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_captioner = pipeline("image-to-text", model="sadassa17/rgb-language_cap", max_new_tokens=200, device=device)

filename = 'path/to/file'
generated_captions = image_captioner(filename)
print(generated_captions)
```

The pipeline returns a list of dictionaries, each with a `generated_text` field. The model is trained to produce as much text as possible, up to a maximum of 200 tokens, which translates to roughly 5 sentences; a 6th sentence is usually cut off.

<i>The output always has the form: "Object1" is to the "left/right/etc." of "Object2".</i>
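
Because the captions follow this fixed pattern, the spatial relations can be pulled out with a simple regular expression. The helper below is an illustrative sketch; the function name, the exact pattern, and the example caption are assumptions, not part of the model:

```python
import re

def extract_relations(caption):
    # Hypothetical helper: pull (object1, relation, object2) triples out of a
    # caption that follows the fixed "X is to the <relation> of the Y." pattern.
    pattern = r'the (\w[\w ]*?) is to the (\w+) of the (\w[\w ]*?)\.'
    return re.findall(pattern, caption, flags=re.IGNORECASE)

# Example (the caption text here is made up for illustration):
print(extract_relations("The dog is to the left of the car. The car is to the right of the tree."))
# [('dog', 'left', 'car'), ('car', 'right', 'tree')]
```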

## If you want to produce a specific number of captions (up to 5)
```python
def print_up_to_n_sentences(captions, n):
    """Print and return the first n sentences of each generated caption."""
    results = []
    for caption in captions:
        generated_text = caption.get('generated_text', '')
        sentences = generated_text.split('.')
        result = '.'.join(sentences[:n])
        print(result)
        results.append(result)
    return results

filename = 'path/to/file'

generated_captions = image_captioner(filename)  # pipeline defined in the snippet above
captions = print_up_to_n_sentences(generated_captions, 5)
print(captions)
```