Update README.md
Browse files
README.md
CHANGED
@@ -31,12 +31,12 @@ You can use the raw model for either feature extractor or (un) conditional image
|
|
31 |
Here is how to use this model in PyTorch to perform unconditional image generation:
|
32 |
|
33 |
```python
|
34 |
-
from transformers import
|
35 |
import torch
|
36 |
import matplotlib.pyplot as plt
|
37 |
import numpy as np
|
38 |
|
39 |
-
|
40 |
model = ImageGPTForCausalImageModeling.from_pretrained('openai/imagegpt-medium')
|
41 |
|
42 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -48,8 +48,8 @@ context = torch.full((batch_size, 1), model.config.vocab_size - 1) #initialize w
|
|
48 |
context = torch.tensor(context).to(device)
|
49 |
output = model.generate(pixel_values=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40)
|
50 |
|
51 |
-
clusters =
|
52 |
-
n_px =
|
53 |
|
54 |
samples = output[:,1:].cpu().detach().numpy()
|
55 |
samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples] # convert color cluster tokens back to pixels
|
|
|
31 |
Here is how to use this model in PyTorch to perform unconditional image generation:
|
32 |
|
33 |
```python
|
34 |
+
from transformers import ImageGPTImageProcessor, ImageGPTForCausalImageModeling
|
35 |
import torch
|
36 |
import matplotlib.pyplot as plt
|
37 |
import numpy as np
|
38 |
|
39 |
+
processor = ImageGPTImageProcessor.from_pretrained('openai/imagegpt-medium')
|
40 |
model = ImageGPTForCausalImageModeling.from_pretrained('openai/imagegpt-medium')
|
41 |
|
42 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
48 |
context = torch.tensor(context).to(device)
|
49 |
output = model.generate(pixel_values=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40)
|
50 |
|
51 |
+
clusters = processor.clusters
|
52 |
+
n_px = processor.size
|
53 |
|
54 |
samples = output[:,1:].cpu().detach().numpy()
|
55 |
samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples] # convert color cluster tokens back to pixels
|