Haowei Chen committed on
Commit
1fdf2ca
·
1 Parent(s): e5245e2

Sketch of text encoder merger

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. pyproject.toml +8 -0
  3. text_encoder_merger.py +90 -0
.gitignore CHANGED
@@ -1 +1,2 @@
1
- *.lock
 
 
1
+ *.lock
2
+ debug.ipynb
pyproject.toml CHANGED
@@ -10,9 +10,17 @@ package-mode = false
10
  python = "^3.12"
11
  transformers = "^4.43.4"
12
  accelerate = "^0.33.0"
 
 
 
 
 
 
 
13
 
14
  [tool.poetry.group.dev.dependencies]
15
  jupyter = "^1.0.0"
 
16
 
17
  [build-system]
18
  requires = ["poetry-core"]
 
10
  python = "^3.12"
11
  transformers = "^4.43.4"
12
  accelerate = "^0.33.0"
13
+ protobuf = "^5.27.3"
14
+ torchvision = "^0.19.0"
15
+ datasets = "^2.21.0"
16
+ safetensors = "^0.4.4"
17
+ evaluate = "^0.4.2"
18
+ diffusers = "^0.30.0"
19
+ torch = "^2.4.0"
20
 
21
  [tool.poetry.group.dev.dependencies]
22
  jupyter = "^1.0.0"
23
+ autopep8 = "^2.3.1"
24
 
25
  [build-system]
26
  requires = ["poetry-core"]
text_encoder_merger.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+ from torch import nn, tensor, concat
3
+ from diffusers.models.embeddings import get_timestep_embedding
4
+ import torch
5
+
6
class T5DiffusionXLTextEncoderMergerConfig(PretrainedConfig):
    """Configuration for the T5 / SDXL text-encoder merger.

    Holds the dimensions used when fusing T5 sequence embeddings with the
    SDXL text-encoder embeddings; all values are persisted on the config
    instance so they round-trip through `PretrainedConfig` serialization.
    """

    def __init__(
        self,
        num_layers: int = 4,
        dim_timestep_embeds: int = 16,
        seq_len: int = 77,
        channels_sdxl: int = 2048,
        channels_t5: int = 4096,
        channels_pooled: int = 1280,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store every merger hyperparameter as a plain attribute.
        self.num_layers = num_layers
        self.dim_timestep_embeds = dim_timestep_embeds
        self.seq_len = seq_len
        self.channels_sdxl = channels_sdxl
        self.channels_t5 = channels_t5
        self.channels_pooled = channels_pooled
25
class T5DiffusionXLTextEncoderMerger(PreTrainedModel, nn.Module):
    """Merge T5 and SDXL text-encoder sequence embeddings into a single
    SDXL-shaped embedding.

    The concatenated (T5 + SDXL) embeddings are passed through two MLP
    stages whose activations are modulated FiLM-style (gamma/beta/zeta)
    by the pooled SDXL embedding and by an embedding of the current
    diffusion timestep, with a residual connection onto the SDXL input.
    """

    def __init__(self, config: T5DiffusionXLTextEncoderMergerConfig):
        super().__init__(config)
        # Timestep consumed by forward(); updated via set_timestep().
        self._last_timestep = 0
        channels_concat = config.channels_sdxl + config.channels_t5

        # Stage 1: project the concatenated embeddings, then normalize over
        # (seq_len, channels) without learned affine parameters.
        self.block_forward1 = nn.Sequential(
            nn.Linear(channels_concat, channels_concat),
            nn.LayerNorm([config.seq_len, channels_concat],
                         elementwise_affine=False))

        # Stage 2: (num_layers - 1) Linear+SiLU blocks, then a projection
        # down to the SDXL channel count squashed with Tanh.
        layers = []
        for _ in range(config.num_layers - 1):
            layers.append(nn.Linear(channels_concat, channels_concat))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(channels_concat, config.channels_sdxl))
        layers.append(nn.Tanh())
        self.block_forward2 = nn.Sequential(*layers)

        # Each modulation branch emits, per sequence position, gamma and
        # beta (channels_concat each) plus zeta (channels_sdxl).
        dim_modulation = config.seq_len * (channels_concat * 2 +
                                           config.channels_sdxl)
        self.block_modulate_by_pooled = nn.Sequential(
            nn.Linear(config.channels_pooled, 512, bias=False), nn.SiLU(),
            nn.Linear(512, dim_modulation, bias=False))

        self.block_modulate_by_timestep = nn.Sequential(
            nn.Linear(config.dim_timestep_embeds, 512, bias=False), nn.SiLU(),
            nn.Linear(512, dim_modulation, bias=False))

    def _init_weights(self, module):
        """Initialize Linear layers with N(0, 0.1) weights and zero biases."""
        if isinstance(module, nn.Linear):
            module.weight.normal_(0, 0.1)
            if module.bias is not None:
                module.bias.zero_()

    def forward(self, embeds_t5, embeds_sdxl, pooled_embeds_sdxl):
        """Fuse the three encoder outputs into one SDXL-shaped embedding.

        Args:
            embeds_t5: T5 sequence embeddings; assumed
                (batch, seq_len, channels_t5) — TODO confirm with caller.
            embeds_sdxl: SDXL sequence embeddings; assumed
                (batch, seq_len, channels_sdxl).
            pooled_embeds_sdxl: pooled SDXL embedding,
                (batch, channels_pooled).

        Returns:
            dict with key "output": merged embeddings shaped like
            `embeds_sdxl`, including a residual add of `embeds_sdxl`.
        """
        batch_size = embeds_sdxl.size(0)
        # BUG FIX: the original asserted embeds_sdxl against itself; the
        # check must validate that all *three* inputs share a batch size.
        assert batch_size == embeds_t5.size(0) == pooled_embeds_sdxl.size(0)
        channels_sdxl = self.config.channels_sdxl
        channels_concat = self.config.channels_t5 + channels_sdxl
        seq_len = self.config.seq_len
        # Embed the scalar timestep and broadcast it across the batch.
        # FIX: build the tensor on the inputs' device so the module also
        # works when the model lives on GPU (the original always used CPU).
        timestep_embeds = get_timestep_embedding(
            tensor([self._last_timestep], device=pooled_embeds_sdxl.device),
            embedding_dim=self.config.dim_timestep_embeds).repeat(
                batch_size, 1)
        # Additive combination of the two modulation sources.
        modulation = self.block_modulate_by_timestep(
            timestep_embeds) + self.block_modulate_by_pooled(pooled_embeds_sdxl)
        # Split the flat modulation vector into per-position gamma/beta/zeta.
        gamma, beta, zeta = [
            m.view(batch_size, seq_len, -1) for m in modulation.split([
                seq_len * channels_concat, seq_len * channels_concat,
                seq_len * channels_sdxl
            ], dim=1)
        ]
        output = (gamma + 1) * self.block_forward1(
            concat((embeds_t5, embeds_sdxl), dim=2)) + beta
        output = (zeta + 1) * self.block_forward2(output)
        # Residual connection onto the original SDXL embeddings.
        output += embeds_sdxl
        return {"output": output}

    def set_timestep(self, timestep: int):
        """Record the diffusion timestep used by the modulation branch."""
        self._last_timestep = timestep