yyk19 commited on
Commit
92b8ea7
·
1 Parent(s): bd89e06

add number of rows parameters

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. cldm/glyph_control.py +0 -230
app.py CHANGED
@@ -112,7 +112,8 @@ with block:
112
  exec(f"""top_left_x_{i} = gr.Slider(label="Bbox Top Left x", minimum=0., maximum=1, value={0.35 - 0.25 * math.cos(math.pi * i)}, step=0.01) """)
113
  exec(f"""top_left_y_{i} = gr.Slider(label="Bbox Top Left y", minimum=0., maximum=1, value={0.1 if i < 2 else 0.6}, step=0.01) """)
114
  exec(f"""yaw_{i} = gr.Slider(label="Bbox Yaw", minimum=-180, maximum=180, value=0, step=5) """)
115
- exec(f"""num_rows_{i} = gr.Slider(label="num_rows", minimum=1, maximum=4, value=1, step=1, visible=False) """)
 
116
 
117
  with gr.Row():
118
  with gr.Column():
 
112
  exec(f"""top_left_x_{i} = gr.Slider(label="Bbox Top Left x", minimum=0., maximum=1, value={0.35 - 0.25 * math.cos(math.pi * i)}, step=0.01) """)
113
  exec(f"""top_left_y_{i} = gr.Slider(label="Bbox Top Left y", minimum=0., maximum=1, value={0.1 if i < 2 else 0.6}, step=0.01) """)
114
  exec(f"""yaw_{i} = gr.Slider(label="Bbox Yaw", minimum=-180, maximum=180, value=0, step=5) """)
115
+ # exec(f"""num_rows_{i} = gr.Slider(label="num_rows", minimum=1, maximum=4, value=1, step=1, visible=False) """)
116
+ exec(f"""num_rows_{i} = gr.Slider(label="num_rows", minimum=1, maximum=4, value=1, step=1) """)
117
 
118
  with gr.Row():
119
  with gr.Column():
cldm/glyph_control.py DELETED
@@ -1,230 +0,0 @@
1
- import torch.nn as nn
2
- from ldm.modules.encoders.modules import OpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder
3
- from ldm.util import instantiate_from_config
4
- import torch
5
- from taming.models.vqgan import VQModelInterfaceEncoder, VQModel
6
- from ldm.modules.attention import SpatialTransformer
7
- from ldm.modules.attention import Normalize, BasicTransformerBlock#, exists
8
- from ldm.modules.diffusionmodules.util import zero_module, identity_init_fc, conv_nd
9
- from einops import rearrange
10
- # from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
11
- def disabled_train(self, mode=True):
12
- """Overwrite model.train with this function to make sure train/eval mode
13
- does not change anymore."""
14
- return self
15
-
16
-
17
-
18
- def make_zero_conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
19
- return zero_module(conv_nd(2, in_channels, out_channels, kernel_size, stride=stride, padding=padding))
20
-
21
-
22
- class SpatialTransformer_v2(nn.Module):
23
- """
24
- Transformer block for image-like data.
25
- First, project the input (aka embedding)
26
- and reshape to b, t, d.
27
- Then apply standard transformer action.
28
- Finally, reshape to image
29
- NEW: use_linear for more efficiency instead of the 1x1 convs
30
- """
31
- def __init__(self, in_channels, n_heads, d_head,
32
- depth=1, dropout=0., context_dim=None,
33
- disable_self_attn=False, use_linear=False,
34
- use_checkpoint=True):
35
- super().__init__()
36
- # change:
37
- # if exists(context_dim) and not isinstance(context_dim, list):
38
- if not isinstance(context_dim, list):
39
- context_dim = [context_dim]
40
- self.in_channels = in_channels
41
- inner_dim = n_heads * d_head
42
- self.norm = Normalize(in_channels)
43
- if not use_linear:
44
- self.proj_in = nn.Conv2d(in_channels,
45
- inner_dim,
46
- kernel_size=1,
47
- stride=1,
48
- padding=0)
49
- else:
50
- self.proj_in = nn.Linear(in_channels, inner_dim)
51
-
52
- self.transformer_blocks = nn.ModuleList(
53
- [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
54
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
55
- for d in range(depth)]
56
- )
57
- if not use_linear:
58
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
59
- in_channels,
60
- kernel_size=1,
61
- stride=1,
62
- padding=0))
63
- else:
64
- self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) # change: switch
65
- self.use_linear = use_linear
66
-
67
- def forward(self, x, context=None):
68
- # note: if no context is given, cross-attention defaults to self-attention
69
- if not isinstance(context, list):
70
- context = [context]
71
- b, c, h, w = x.shape
72
- x_in = x
73
- x = self.norm(x)
74
- if not self.use_linear:
75
- x = self.proj_in(x)
76
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
77
- if self.use_linear:
78
- x = self.proj_in(x)
79
- for i, block in enumerate(self.transformer_blocks):
80
- x = block(x, context=context[i])
81
- if self.use_linear:
82
- x = self.proj_out(x)
83
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
84
- if not self.use_linear:
85
- x = self.proj_out(x)
86
- return x + x_in
87
-
88
- class trans_glyph_emb(nn.Module):
89
- def __init__(self,
90
- type = "fc", # "conv", "attn"
91
- input_dim = 256,
92
- out_dim = 1024,
93
- # fc
94
- fc_init = "zero",
95
- # conv/attn
96
- conv_ks = 3,
97
- conv_pad = 1,
98
- conv_stride = 1,
99
- # attn
100
- ch = 512, # 1024
101
- num_heads = 8, # 16
102
- dim_head = 64,
103
- use_linear_in_transformer = True,
104
- use_checkpoint = False, #True,
105
- ):
106
- super().__init__()
107
-
108
- if type == "fc":
109
- self.model = torch.nn.Linear(input_dim, out_dim)
110
- if fc_init == "zero":
111
- self.model = zero_module(self.model)
112
- elif fc_init == "identity":
113
- self.model = identity_init_fc(self.model)
114
- elif type == "conv":
115
- self.model = make_zero_conv(input_dim, out_dim, conv_ks, stride = conv_stride, padding = conv_pad)
116
- elif type == "attn":
117
- model = [
118
- # nn.Conv2d(input_dim, ch, 3, stride = 1, padding = 1),
119
- nn.Conv2d(input_dim, ch, conv_ks, stride = conv_stride, padding = conv_pad),
120
- SpatialTransformer_v2( #SpatialTransformer(
121
- ch, num_heads, dim_head, depth=1, context_dim=None, #ch,
122
- disable_self_attn=False, use_linear=use_linear_in_transformer,
123
- use_checkpoint=use_checkpoint, # False if the context is None
124
- ),
125
- make_zero_conv(ch, out_dim, 1, stride = 1, padding = 0)
126
- # make_zero_conv(ch, out_dim, conv_ks, stride = conv_stride, padding = conv_pad)
127
- ]
128
- self.model = nn.Sequential(*model)
129
- self.model_type = type
130
-
131
- def forward(self, x):
132
- if self.model_type == "fc":
133
- # b, c, h, w = x.shape
134
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
135
- x = self.model(x)
136
- # x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
137
- # return x
138
- else:
139
- x = self.model(x)
140
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
141
- return x
142
-
143
-
144
-
145
- class glyph_control(nn.Module):
146
- def __init__(self,
147
- image_encoder = "CLIP", # "VQGAN"
148
- image_encoder_config = None,
149
- fuse_way = "concat",
150
- load_text_encoder = False,
151
- text_encoder_config = None,
152
- freeze_image_encoder = True,
153
- trans_emb = False,
154
- trans_emb_config = None,
155
- # use_fp16 = False,
156
- ):
157
- super().__init__()
158
- if image_encoder_config is not None:
159
- image_encoder_config.params.freeze = freeze_image_encoder
160
- self.image_encoder = instantiate_from_config(image_encoder_config)
161
- else:
162
- if image_encoder == "CLIP":
163
- self.image_encoder = OpenCLIPImageEmbedder(freeze=freeze_image_encoder)
164
- elif image_encoder == "VQGAN":
165
- print("VQGAN glyph image encoder is missing config")
166
- raise ValueError
167
- else:
168
- print("Other types of glyph image encoder are not supported")
169
- raise ValueError
170
-
171
- if freeze_image_encoder:
172
- self.freeze_imenc()
173
- self.freeze_image_encoder = freeze_image_encoder
174
- self.image_encoder_type = image_encoder
175
-
176
-
177
- if load_text_encoder:
178
- if text_encoder_config is None:
179
- self.text_encoder = FrozenOpenCLIPEmbedder()
180
- else:
181
- self.text_encoder = instantiate_from_config(text_encoder_config)
182
- self.fuse_way = fuse_way
183
- # self.dtype = torch.float16 if use_fp16 else torch.float32
184
- if trans_emb:
185
- if trans_emb_config is not None:
186
- self.trans_glyph_emb_model = instantiate_from_config(trans_emb_config)
187
- else:
188
- self.trans_glyph_emb_model = trans_glyph_emb()
189
- else:
190
- self.trans_glyph_emb_model = None
191
-
192
- def freeze_imenc(self):
193
- self.image_encoder = self.image_encoder.eval()
194
- self.image_encoder.train = disabled_train
195
- for param in self.image_encoder.parameters():
196
- param.requires_grad = False
197
-
198
- def forward(self, glyph_image, text = None, text_embed = None):
199
- clgim_num_list = [img.shape[0] for img in glyph_image]
200
- # image_embeds = self.image_encoder(torch.concat(glyph_image, dim=0))
201
- gim_concat = torch.concat(glyph_image, dim=0)
202
- image_embeds = self.image_encoder(gim_concat)
203
- if self.trans_glyph_emb_model is not None:
204
- image_embeds = self.trans_glyph_emb_model(image_embeds)
205
- image_embeds = torch.split(image_embeds, clgim_num_list)
206
- max_image_tokens = max(clgim_num_list)
207
- pad_image_embeds = []
208
- for image_embed in image_embeds:
209
- if image_embed.shape[0] < max_image_tokens:
210
- image_embed = torch.concat([
211
- image_embed,
212
- torch.zeros(
213
- (max_image_tokens - image_embed.shape[0], *image_embed.shape[1:]), device=image_embed.device, dtype=image_embed.dtype, # add dtype
214
- )], dim=0
215
- )
216
- pad_image_embeds.append(image_embed)
217
- pad_image_embeds = torch.stack(pad_image_embeds, dim = 0)
218
- if text_embed is None:
219
- assert self.text_encoder, text is not None
220
- text_embed = self.text_encoder(text)
221
- if self.fuse_way == "concat":
222
- assert pad_image_embeds.shape[-1] == text_embed.shape[-1]
223
- if len(pad_image_embeds.shape) == 4:
224
- b, _, _ , embdim = pad_image_embeds.shape
225
- pad_image_embeds = pad_image_embeds.view(b, -1, embdim)
226
- out_embed = torch.concat([text_embed, pad_image_embeds], dim= 1)
227
- print("concat glyph_embed with text_embed:", out_embed.shape)
228
- return out_embed
229
- else:
230
- raise ValueError("Not support other fuse ways for now!")