keithhon commited on
Commit
ecc7fff
·
1 Parent(s): f6ad06e

Upload dalle/models/__init__.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dalle/models/__init__.py +198 -0
dalle/models/__init__.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import torch
9
+ import logging
10
+ import torch.nn as nn
11
+ import pytorch_lightning as pl
12
+ from typing import Optional, Tuple
13
+ from omegaconf import OmegaConf
14
+ from torch.cuda.amp import autocast
15
+ from torch.optim.lr_scheduler import CosineAnnealingLR
16
+ from torch.nn import functional as F
17
+ from .stage1.vqgan import VQGAN
18
+ from .stage2.transformer import Transformer1d, iGPT
19
+ from .. import utils
20
+ from ..utils.config import get_base_config
21
+ from ..utils.sampling import sampling, sampling_igpt
22
+ from .tokenizer import build_tokenizer
23
+
24
+ _MODELS = {
25
+ 'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
26
+ }
27
+
28
+
29
+ class Dalle(nn.Module):
30
+ def __init__(self,
31
+ config: OmegaConf) -> None:
32
+ super().__init__()
33
+ self.tokenizer = None
34
+ self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
35
+ embed_dim=config.stage1.embed_dim,
36
+ hparams=config.stage1.hparams)
37
+ self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
38
+ vocab_size_img=config.stage2.vocab_size_img,
39
+ hparams=config.stage2.hparams)
40
+ self.config_stage1 = config.stage1
41
+ self.config_stage2 = config.stage2
42
+ self.config_dataset = config.dataset
43
+
44
+ @classmethod
45
+ def from_pretrained(cls,
46
+ path: str) -> nn.Module:
47
+ config_base = get_base_config()
48
+ config_new = OmegaConf.load('config.yaml')
49
+ config_update = OmegaConf.merge(config_base, config_new)
50
+
51
+ model = cls(config_update)
52
+ model.tokenizer = build_tokenizer('tokenizer',
53
+ context_length=model.config_dataset.context_length,
54
+ lowercase=True,
55
+ dropout=None)
56
+ return model
57
+
58
+ @torch.no_grad()
59
+ def sampling(self,
60
+ prompt: str,
61
+ top_k: int = 256,
62
+ top_p: Optional[float] = None,
63
+ softmax_temperature: float = 1.0,
64
+ num_candidates: int = 96,
65
+ device: str = 'cuda:0',
66
+ use_fp16: bool = True) -> torch.FloatTensor:
67
+ self.stage1.eval()
68
+ self.stage2.eval()
69
+
70
+ tokens = self.tokenizer.encode(prompt)
71
+ tokens = torch.LongTensor(tokens.ids)
72
+ tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
73
+
74
+ # Check if the encoding works as intended
75
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
76
+
77
+ tokens = tokens.to(device)
78
+ codes = sampling(self.stage2,
79
+ tokens,
80
+ top_k=top_k,
81
+ top_p=top_p,
82
+ softmax_temperature=softmax_temperature,
83
+ use_fp16=use_fp16)
84
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
85
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
86
+ return pixels
87
+
88
+
89
+ class ImageGPT(pl.LightningModule):
90
+ def __init__(self,
91
+ config: OmegaConf) -> None:
92
+ super().__init__()
93
+ self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
94
+ embed_dim=config.stage1.embed_dim,
95
+ hparams=config.stage1.hparams)
96
+ self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
97
+ use_cls_cond=config.stage2.use_cls_cond,
98
+ hparams=config.stage2.hparams)
99
+ self.config = config
100
+ self.use_cls_cond = config.stage2.use_cls_cond
101
+
102
+ # make the parameters in stage 1 not trainable
103
+ self.stage1.eval()
104
+ for p in self.stage1.parameters():
105
+ p.requires_grad = False
106
+
107
+ @classmethod
108
+ def from_pretrained(cls,
109
+ path_upstream: str,
110
+ path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
111
+ config_base = get_base_config(use_default=False)
112
+ config_down = OmegaConf.load(path_downstream)
113
+ config_down = OmegaConf.merge(config_base, config_down)
114
+
115
+ model = cls(config_down)
116
+ model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
117
+ model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
118
+ return model, config_down
119
+
120
+ def sample(self,
121
+ cls_idx: Optional[int] = None,
122
+ top_k: int = 256,
123
+ top_p: Optional[float] = None,
124
+ softmax_temperature: float = 1.0,
125
+ num_candidates: int = 16,
126
+ device: str = 'cuda:0',
127
+ use_fp16: bool = True,
128
+ is_tqdm: bool = True) -> torch.FloatTensor:
129
+ self.stage1.eval()
130
+ self.stage2.eval()
131
+
132
+ if cls_idx is None:
133
+ sos = self.stage2.sos.repeat(num_candidates, 1, 1)
134
+ else:
135
+ sos = torch.LongTensor([cls_idx]).to(device=device)
136
+ sos = sos.repeat(num_candidates)
137
+ sos = self.stage2.sos(sos).unsqueeze(1)
138
+
139
+ codes = sampling_igpt(self.stage2,
140
+ sos=sos,
141
+ top_k=top_k,
142
+ top_p=top_p,
143
+ softmax_temperature=softmax_temperature,
144
+ use_fp16=use_fp16,
145
+ is_tqdm=is_tqdm)
146
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
147
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
148
+ return pixels
149
+
150
+ def forward(self,
151
+ images: torch.FloatTensor,
152
+ labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
153
+ B, C, H, W = images.shape
154
+ with torch.no_grad():
155
+ with autocast(enabled=False):
156
+ codes = self.stage1.get_codes(images).detach()
157
+ logits = self.stage2(codes, labels)
158
+ return logits, codes
159
+
160
+ def training_step(self, batch, batch_idx):
161
+ images, labels = batch
162
+ logits, codes = self(images, labels=labels if self.use_cls_cond else None)
163
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
164
+ self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
165
+ return loss
166
+
167
+ def validation_step(self, batch, batch_idx):
168
+ images, labels = batch
169
+ logits, codes = self(images, labels=labels if self.use_cls_cond else None)
170
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
171
+ self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
172
+ return loss
173
+
174
+ def configure_optimizers(self):
175
+ assert self.config.optimizer.opt_type == 'adamW'
176
+ assert self.config.optimizer.sched_type == 'cosine'
177
+
178
+ opt = torch.optim.AdamW(self.parameters(),
179
+ lr=self.config.optimizer.base_lr,
180
+ betas=self.config.optimizer.betas,
181
+ weight_decay=self.config.optimizer.weight_decay)
182
+ sched = CosineAnnealingLR(opt,
183
+ T_max=self.config.optimizer.max_steps,
184
+ eta_min=self.config.optimizer.min_lr)
185
+ sched = {
186
+ 'scheduler': sched,
187
+ 'name': 'cosine'
188
+ }
189
+ return [opt], [sched]
190
+
191
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
192
+ on_tpu=False, using_native_amp=False, using_lbfgs=False):
193
+ optimizer.step(closure=optimizer_closure)
194
+ self.lr_schedulers().step()
195
+ self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
196
+
197
+ def on_epoch_start(self):
198
+ self.stage1.eval()