---
language: 
- fr

thumbnail: https://raw.githubusercontent.com/AntoineSimoulin/gpt-fr/main/imgs/logo.png
tags:
- tf
- pytorch
- gpt2
- text-to-image
license: apache-2.0
---

<img src="https://raw.githubusercontent.com/AntoineSimoulin/gpt-fr/main/imgs/igpt-logo.png" width="400">

## Model description

**iGPT-fr** 🇫🇷 is a pre-trained incremental GPT language model for French, developed by the [Laboratoire de Linguistique Formelle (LLF)](http://www.llf.cnrs.fr/en). We adapted the [GPT-fr 🇫🇷](https://huggingface.co/asi/gpt-fr-cased-base) model to generate images conditioned on text inputs.

## Intended uses & limitations

The model can be used for image generation tasks. It is currently under development.

#### How to use

The model can be used through the 🤗 `Transformers` library. You will also need to install the `Taming Transformers` library for high-resolution image synthesis:

```bash
pip install git+https://github.com/CompVis/taming-transformers.git
```

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from taming.models import vqgan
import torch
from PIL import Image
import numpy as np

# Load VQGAN model
vqgan_ckpt = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="model.ckpt", force_download=False)
vqgan_config = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="config.yaml", force_download=False)

config = OmegaConf.load(vqgan_config)
vqgan_model = vqgan.VQModel(**config.model.params)
vqgan_model.eval().requires_grad_(False)
vqgan_model.init_from_ckpt(vqgan_ckpt)

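# Load pretrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained("asi/igpt-fr-cased-base").to(device)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("asi/igpt-fr-cased-base")

# Encode the text prompt and append the image generation token
input_sentence = "Une carte de l'europe"
input_ids = tokenizer.encode(input_sentence, return_tensors='pt')
input_ids = torch.cat((input_ids, torch.tensor([[50000]])), 1)  # Add image generation token

# Sample 256 image tokens with nucleus sampling (top_p=0.92, top_k disabled)
sampled_output = model.generate(
  input_ids.to(device),
  max_length=256+input_ids.shape[1],
  do_sample=True,
  top_p=0.92,
  top_k=0)

def custom_to_pil(x):
  """Convert a (3, H, W) tensor with values in [-1, 1] to an RGB PIL image."""
  x = x.detach().cpu()
  x = torch.clamp(x, -1., 1.)
  x = (x + 1.)/2.  # rescale to [0, 1]
  x = x.permute(1,2,0).numpy()  # channels-last layout expected by PIL
  x = (255*x).astype(np.uint8)
  x = Image.fromarray(x)
  if x.mode != "RGB":
    x = x.convert("RGB")
  return x

# Drop the prompt, shift token IDs back to VQGAN codebook indices, and decode.
# The indices are moved to CPU because the VQGAN model was left on CPU above.
z_idx = (sampled_output[0, input_ids.shape[1]:] - 50001).cpu()
z_quant = vqgan_model.quantize.get_codebook_entry(z_idx, shape=(1, 16, 16, 256))
x_rec = vqgan_model.decode(z_quant).to('cpu')[0]
display(custom_to_pil(x_rec))  # `display` is provided by Jupyter/IPython notebooks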
```
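
The slicing and offset at the end of the snippet above follow from how the generated sequence is laid out: the prompt tokens, then the image generation token (ID 50000), then 256 image tokens whose IDs are shifted by 50001 relative to the VQGAN codebook indices. The sketch below restates that mapping in plain Python on a toy sequence. The helper name `to_codebook_indices` and the codebook size of 16384 (inferred from the checkpoint name `vqgan_f16_16384`) are assumptions made for illustration and are not part of the released code.

```python
IMAGE_START_TOKEN = 50000   # image generation token used in the snippet above
IMAGE_TOKEN_OFFSET = 50001  # offset subtracted from the generated IDs above
GRID_SIZE = 16              # the decoder is called with a 16 x 16 latent grid
NUM_IMAGE_TOKENS = GRID_SIZE * GRID_SIZE
CODEBOOK_SIZE = 16384       # assumption: inferred from the name "vqgan_f16_16384"


def to_codebook_indices(sequence, prompt_length):
    """Map the generated tail of a token sequence to a 16 x 16 grid of codebook indices.

    `prompt_length` counts the text tokens plus the image generation token.
    """
    image_tokens = list(sequence[prompt_length:])
    if len(image_tokens) != NUM_IMAGE_TOKENS:
        raise ValueError(f"expected {NUM_IMAGE_TOKENS} image tokens, got {len(image_tokens)}")
    indices = [token - IMAGE_TOKEN_OFFSET for token in image_tokens]
    if any(i < 0 or i >= CODEBOOK_SIZE for i in indices):
        raise ValueError("a generated token falls outside the image codebook range")
    # Row-major reshape into GRID_SIZE rows of GRID_SIZE indices each
    return [indices[r * GRID_SIZE:(r + 1) * GRID_SIZE] for r in range(GRID_SIZE)]


# Toy sequence: three hypothetical text tokens, the image token, then 256 image tokens.
prompt = [10, 20, 30, IMAGE_START_TOKEN]
image_part = [IMAGE_TOKEN_OFFSET + i for i in range(NUM_IMAGE_TOKENS)]
grid = to_codebook_indices(prompt + image_part, prompt_length=len(prompt))
print(len(grid), len(grid[0]))
```

A range check of this kind may help spot samples in which a text token was drawn in the image region, since subtracting the offset would then give a negative index.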

You can also rank the generated images with CLIP and keep only the best matches for the prompt:

```python
from tqdm import tqdm

def hallucinate(prompt, num_images=64):
    """Sample `num_images` images for `prompt` and return them as PIL images."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    input_ids = torch.cat((input_ids, torch.tensor([[50000]])), 1).to(device)  # Add image generation token

    all_images = []
    for _ in tqdm(range(num_images)):
        sampled_output = model.generate(
          input_ids,
          max_length=256+input_ids.shape[1],
          do_sample=True,
          top_p=0.92,
          top_k=0)

        # Shift token IDs back to VQGAN codebook indices (on CPU, like the VQGAN model) and decode
        z_idx = (sampled_output[0, input_ids.shape[1]:] - 50001).cpu()
        z_quant = vqgan_model.quantize.get_codebook_entry(z_idx, shape=(1, 16, 16, 256))
        x_rec = vqgan_model.decode(z_quant).to('cpu')[0]
        all_images.append(custom_to_pil(x_rec))
    return all_images

input_sentence = "Une carte de l'europe"
all_images = hallucinate(input_sentence)

from transformers import pipeline

opus_model = "Helsinki-NLP/opus-mt-fr-en"
opus_translator = pipeline("translation", model=opus_model)

opus_translator(input_sentence)

from transformers import CLIPProcessor, CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_top_k(prompt, images, k=8):
  """Return the `k` images that CLIP scores as most similar to `prompt`."""
  # Translate the French prompt to English before scoring it with CLIP
  prompt_en = opus_translator(prompt)[0]['translation_text']
  inputs = clip_processor(text=prompt_en, images=images, return_tensors="pt", padding=True)
  outputs = clip_model(**inputs)
  logits = outputs.logits_per_text  # text-image similarity scores, shape (1, num_images)
  top_indices = logits[0].detach().numpy().argsort()[-k:][::-1]  # indices of the k highest scores
  return [images[idx] for idx in top_indices]

filtered_images = clip_top_k(input_sentence, all_images)

for fi in filtered_images:
  display(fi)
```
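
The selection step inside `clip_top_k` is an argsort over the similarity scores, keeping the last `k` positions in reverse order. As a minimal sketch of that ranking alone, the pure-Python helper below (the name `top_k_indices` and the scores are made up for illustration) returns the positions of the `k` highest scores, best first:

```python
def top_k_indices(scores, k):
    """Return the positions of the k largest scores, ordered from best to worst."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


# Hypothetical similarity scores for four generated images
example_scores = [0.1, 0.9, 0.5, 0.7]
best = top_k_indices(example_scores, k=2)
print(best)
```

Apart from how ties are ordered, this should select the same positions as the NumPy expression used in `clip_top_k`.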

## Training data

We created a dedicated corpus to train our generative model. The training corpus consists of text-image pairs. We aggregated portions of existing corpora: [LAION-5B](https://laion.ai/blog/laion-5b/) and [WIT](https://github.com/google-research-datasets/wit). The final dataset includes 10,807,534 samples.

## Training procedure

We pre-trained the model on the new CNRS (French National Centre for Scientific Research) [Jean Zay](http://www.idris.fr/eng/jean-zay/) supercomputer. Training took a total of 140 hours of computation on Tesla V100 hardware (TDP of 300W). The training was distributed over 8 compute nodes of 8 GPUs each. We used data parallelism to divide each micro-batch across the computing units. We estimated the total emissions at 1161.22 kgCO2eq, using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al., (2019)](lacoste-2019).
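
As a rough check of the emissions figure, the estimate can be reproduced as energy used multiplied by a carbon intensity. The sketch below assumes that the 140 hours are wall-clock time on all 64 GPUs, that each GPU draws its full 300 W TDP, and that the carbon intensity is 0.432 kgCO2eq/kWh. That last value is not stated above; it is an assumption chosen because it reproduces the reported total.

```python
NUM_NODES = 8
GPUS_PER_NODE = 8
TRAINING_HOURS = 140             # assumption: wall-clock hours on all GPUs
GPU_POWER_KW = 0.300             # Tesla V100 TDP of 300 W
CARBON_INTENSITY_KG_PER_KWH = 0.432  # assumption: not stated in the text above

num_gpus = NUM_NODES * GPUS_PER_NODE
energy_kwh = num_gpus * TRAINING_HOURS * GPU_POWER_KW
emissions_kg = energy_kwh * CARBON_INTENSITY_KG_PER_KWH

print(f"{num_gpus} GPUs, {energy_kwh:.0f} kWh, {emissions_kg:.2f} kgCO2eq")
```

Under these assumptions the result matches the reported 1161.22 kgCO2eq. The sketch does not account for cooling or other datacenter overheads.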