chumengl commited on
Commit
d443994
·
1 Parent(s): ef557d2

upload model

Browse files
Files changed (7) hide show
  1. README.md +42 -1
  2. __init__.py +6 -0
  3. config.json +29 -0
  4. config.py +63 -0
  5. inference.py +60 -0
  6. model.py +551 -0
  7. model.safetensors +3 -0
README.md CHANGED
@@ -1,3 +1,44 @@
1
  ---
2
- license: mit
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ datasets:
3
+ - Skylion007/openwebtext
4
+ language:
5
+ - en
6
+ library_name: transformers
7
+ license: apache-2.0
8
+ metrics:
9
+ - perplexity
10
+ pipeline_tag: text-generation
11
  ---
12
+
13
+ # LangFlow
14
+
15
+ LangFlow is a continuous diffusion language model that operates in embedding space. Unlike discrete diffusion models (MDLM, SEDD, DUO), LangFlow performs diffusion directly on continuous token embeddings, enabling smoother denoising dynamics.
16
+
17
+ ## Using LangFlow
18
+
19
+ To use the pre-trained model for text generation, use the following snippet:
20
+
21
+ ```python
22
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
23
+
24
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
25
+ model = AutoModelForMaskedLM.from_pretrained('chumengl/langflow-owt', trust_remote_code=True)
26
+
27
+ # Generate samples
28
+ samples = model.generate_samples(num_samples=5, num_steps=128)
29
+ texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
30
+ for text in texts:
31
+ print(text)
32
+ ```
33
+
34
+ ## Model Details
35
+
36
+ - **Architecture**: DiT (Diffusion Transformer) backbone with adaptive layer normalization
37
+ - **Context Length**: 1024 tokens
38
+ - **Parameters**: ~130M non-embedding parameters (similar to GPT-2 medium)
39
+ - **Training**: 1M steps on OpenWebText corpus
40
+ - **Tokenizer**: GPT-2 tokenizer (50,257 vocab size)
41
+
42
+ ## Model Card Contact
43
+
44
+ Contact: chumengl@illinois.edu
__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """HuggingFace release package for LangFlow."""
2
+
3
+ from .config import LangFlowConfig
4
+ from .model import LangFlow
5
+
6
+ __all__ = ["LangFlowConfig", "LangFlow"]
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "chumengl/langflow-owt",
3
+ "architectures": [
4
+ "LangFlow"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "config.LangFlowConfig",
8
+ "AutoModelForMaskedLM": "model.LangFlow"
9
+ },
10
+ "model_type": "LangFlow",
11
+ "vocab_size": 50257,
12
+ "hidden_size": 768,
13
+ "cond_dim": 128,
14
+ "n_blocks": 12,
15
+ "n_heads": 12,
16
+ "dropout": 0.1,
17
+ "model_length": 1024,
18
+ "use_normalized_embedding": true,
19
+ "embedding_norm_method": "layernorm",
20
+ "self_conditioning": true,
21
+ "use_bias": true,
22
+ "gumbel_loc": 4.723,
23
+ "gumbel_scale": 0.852,
24
+ "gumbel_cutoff": 1e-5,
25
+ "gumbel_entropy": 7.02,
26
+ "return_dict": true,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.38.2"
29
+ }
config.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace configuration class for LangFlow."""
2
+
3
+ import transformers
4
+
5
+
6
class LangFlowConfig(transformers.PretrainedConfig):
    """HuggingFace configuration class for LangFlow.

    LangFlow is a continuous diffusion language model that operates in embedding space.
    It uses a DiT (Diffusion Transformer) backbone with adaptive layer normalization.

    Key features:
    - Continuous diffusion in embedding space
    - Self-conditioning: uses previous predictions as additional input
    - Bias (preconditioning): skip connection for improved training
    - Normalized embeddings: layernorm on embedding vectors
    - Learnable Gumbel proposal for gamma (log-SNR) sampling

    Args:
        vocab_size: Vocabulary size (defaults to the GPT-2 tokenizer's 50257).
        hidden_size: Transformer hidden / embedding dimension.
        cond_dim: Dimension of the timestep (gamma) conditioning vector.
        n_blocks: Number of DiT blocks in the backbone.
        n_heads: Number of attention heads per block.
        dropout: Dropout probability inside the blocks.
        model_length: Maximum / default sequence length.
        use_normalized_embedding: Whether embedding rows are renormalized on the fly.
        embedding_norm_method: Name of the normalization scheme (e.g. "layernorm").
        self_conditioning: Whether the backbone consumes its previous prediction.
        use_bias: Whether the bias (preconditioning) skip connection is applied.
        gumbel_loc: Location parameter of the learnable Gumbel proposal.
        gumbel_scale: Scale parameter of the learnable Gumbel proposal.
        gumbel_cutoff: Tail cutoff used to derive the gamma clamp range.
        gumbel_entropy: Learnable entropy scalar carried by the proposal.
        **kwargs: Forwarded to ``transformers.PretrainedConfig``.
    """
    model_type = "LangFlow"

    def __init__(
        self,
        vocab_size: int = 50257,
        hidden_size: int = 768,
        cond_dim: int = 128,
        n_blocks: int = 12,
        n_heads: int = 12,
        dropout: float = 0.1,
        model_length: int = 1024,
        # Embedding normalization
        use_normalized_embedding: bool = True,
        embedding_norm_method: str = "layernorm",
        # Self-conditioning
        self_conditioning: bool = True,
        # Bias (preconditioning) - always enabled for inference
        use_bias: bool = True,
        # Gumbel proposal parameters (learnable)
        gumbel_loc: float = 4.723,
        gumbel_scale: float = 0.852,
        gumbel_cutoff: float = 1e-5,
        gumbel_entropy: float = 7.02,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.cond_dim = cond_dim
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.dropout = dropout
        self.model_length = model_length
        # Embedding normalization
        self.use_normalized_embedding = use_normalized_embedding
        self.embedding_norm_method = embedding_norm_method
        # Self-conditioning
        self.self_conditioning = self_conditioning
        # Bias (preconditioning)
        self.use_bias = use_bias
        # Gumbel proposal parameters
        self.gumbel_loc = gumbel_loc
        self.gumbel_scale = gumbel_scale
        self.gumbel_cutoff = gumbel_cutoff
        self.gumbel_entropy = gumbel_entropy
inference.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Simple inference script to test the HuggingFace LangFlow model."""
2
+
3
+ import argparse
4
+ import torch
5
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
6
+
7
+
8
def main():
    """CLI entry point: load a pretrained LangFlow model and print samples."""
    parser = argparse.ArgumentParser(description="Generate samples with LangFlow")
    parser.add_argument("--model_path", type=str, default="hf_release/model_weights",
                        help="Path to the HuggingFace model directory")
    parser.add_argument("--num_samples", type=int, default=5,
                        help="Number of samples to generate")
    parser.add_argument("--num_steps", type=int, default=128,
                        help="Number of denoising steps")
    parser.add_argument("--seq_length", type=int, default=1024,
                        help="Sequence length")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    # Seed CPU and (when present) all CUDA RNGs for reproducible sampling.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # GPT-2 tokenizer decodes the generated token ids; the model itself is
    # loaded through the repo's custom code (trust_remote_code).
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForMaskedLM.from_pretrained(
        args.model_path,
        trust_remote_code=True
    ).to(device)
    model.eval()

    print(f"\nGenerating {args.num_samples} samples with {args.num_steps} steps...")
    with torch.no_grad():
        samples = model.generate_samples(
            num_samples=args.num_samples,
            seq_length=args.seq_length,
            num_steps=args.num_steps,
            device=device)

    decoded = tokenizer.batch_decode(samples, skip_special_tokens=True)
    for i, text in enumerate(decoded):
        print(f"\n--- Sample {i+1} ---")
        # Print first 500 characters to keep output manageable
        print(text[:500] + ("..." if len(text) > 500 else ""))


if __name__ == "__main__":
    main()
model.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace model implementation for LangFlow.
2
+
3
+ LangFlow is a continuous diffusion language model that operates in embedding space.
4
+ """
5
+
6
+ import math
7
+ import typing
8
+
9
+ import einops
10
+ import flash_attn
11
+ import flash_attn.layers.rotary
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import transformers
16
+
17
+ from .config import LangFlowConfig
18
+
19
+
20
+ # Flags required to enable jit fusion kernels
21
+ torch._C._jit_set_profiling_mode(False)
22
+ torch._C._jit_set_profiling_executor(False)
23
+ torch._C._jit_override_can_fuse_on_cpu(True)
24
+ torch._C._jit_override_can_fuse_on_gpu(True)
25
+
26
+
27
def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool) -> torch.Tensor:
    """Fused helper: optional bias add, dropout, scaling, optional residual.

    Computes ``scale * dropout(x + bias)`` (bias skipped when None) and,
    when a residual is supplied, adds it on top. Written with explicit
    ``is not None`` branches so TorchScript can refine the Optionals
    when this is inlined into the scripted wrappers.
    """
    if bias is not None:
        dropped = F.dropout(x + bias, p=prob, training=training)
    else:
        dropped = F.dropout(x, p=prob, training=training)
    out = scale * dropped
    if residual is not None:
        out = residual + out
    return out
42
+
43
+
44
@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
    """JIT-fused bias/dropout/scale/residual with dropout active (training)."""
    return bias_dropout_add_scale(x, bias, scale, residual, prob, True)
52
+
53
+
54
@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
    """JIT-fused bias/dropout/scale/residual with dropout disabled (eval)."""
    return bias_dropout_add_scale(x, bias, scale, residual, prob, False)
62
+
63
+
64
@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
    """adaLN modulation: ``x * (1 + scale) + shift`` with broadcasting."""
    gain = scale + 1
    return shift + x * gain
69
+
70
+
71
class Rotary(nn.Module):
    """Rotary positional embedding with a lazily built cos/sin cache.

    The caches are shaped [1, S, 3, 1, dim] to broadcast against a packed
    qkv tensor; the third slot (index 2, the values) is forced to the
    identity rotation (cos=1, sin=0) so only q and k are rotated.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return (cos, sin) caches covering ``x.shape[seq_dim]`` positions."""
        seq_len = x.shape[seq_dim]
        # Fix: rebuild the cache when the sequence length OR the device
        # changes. The original only keyed on seq_len, so moving the module
        # to a different device between calls returned stale-device tensors.
        if (seq_len != self.seq_len_cached
                or self.cos_cached is None
                or self.cos_cached.device != x.device):
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            # Identity rotation for the value slot.
            self.cos_cached[:, :, 2, :, :].fill_(1.)
            self.sin_cached[:, :, 2, :, :].fill_(0.)
        return self.cos_cached, self.sin_cached
92
+
93
+
94
def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
    """Split a packed qkv tensor and apply rotary embeddings to q and k.

    Args:
        qkv: [B, S, 3, H, D] packed query/key/value tensor.
        rotary_cos_sin: (cos, sin) caches shaped [1, S, 3, 1, D] as
            produced by ``Rotary.forward``.

    Returns:
        Tuple ``(q, k, v)`` of [B, S, H, D] tensors; rotation is applied to
        q and k only, v is passed through unchanged.
    """
    # Rotary math runs outside autocast, in the qkv dtype.
    with torch.autocast(device_type='cuda', enabled=False):
        cos, sin = rotary_cos_sin
        cos = cos.to(qkv.dtype)
        sin = sin.to(qkv.dtype)
        # Keep only the first half of the duplicated frequency axis —
        # the layout apply_rotary_emb_torch expects.
        cos = cos[0, :, 0, 0, :cos.shape[-1]//2]
        sin = sin[0, :, 0, 0, :sin.shape[-1]//2]
        q, k, v = qkv.chunk(3, dim=2)
        q = flash_attn.layers.rotary.apply_rotary_emb_torch(
            q.squeeze(dim=2), cos, sin)
        k = flash_attn.layers.rotary.apply_rotary_emb_torch(
            k.squeeze(dim=2), cos, sin)
        v = v.squeeze(dim=2)
    return q, k, v
108
+
109
+
110
def regular_attention_multi_headed(q, k, v):
    """Full (non-causal) multi-head attention via torch's fused SDPA.

    Args:
        q, k, v: [B, S, H, D] per-head queries, keys and values.

    Returns:
        [B, S, H*D] attention output with the head axis merged back in.
    """
    # scaled_dot_product_attention expects [B, H, S, D]; transpose in/out.
    attention_output = F.scaled_dot_product_attention(
        query=q.transpose(1, 2),
        key=k.transpose(1, 2),
        value=v.transpose(1, 2),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False)
    attention_output = attention_output.transpose(1, 2)
    # Merge heads with a plain reshape — equivalent to the previous
    # einops.rearrange(..., 'b s h d -> b s (h d)') without the einops dep.
    b, s, h, d = attention_output.shape
    return attention_output.reshape(b, s, h * d)
120
+
121
+
122
class LayerNorm(nn.Module):
    """Bias-free LayerNorm with a learned per-channel gain.

    Normalization is always computed in float32 (outside any autocast
    region) for numerical stability; the gain is applied afterwards.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x):
        # Force fp32 normalization regardless of surrounding autocast.
        with torch.autocast(device_type='cuda', enabled=False):
            normed = F.layer_norm(x.float(), [self.dim])
        return normed * self.weight[None, None, :]
132
+
133
+
134
class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps via sinusoidal features followed by an MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True))
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal features of ``t`` [B]: concatenated cos/sin over
        ``dim // 2`` geometrically spaced frequencies, zero-padded to
        ``dim`` when it is odd."""
        half = dim // 2
        exponent = torch.arange(half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponent)
        args = t[:, None].float() * freqs[None]
        features = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            pad = torch.zeros_like(features[:, :1])
            features = torch.cat([features, pad], dim=-1)
        return features

    def forward(self, t):
        """Return [B, hidden_size] embeddings for scalar timesteps ``t`` [B]."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
163
+
164
+
165
class DDiTBlock(nn.Module):
    """DiT transformer block.

    Rotary self-attention followed by an MLP; each sub-layer is wrapped in
    adaLN shift/scale modulation and a gated, dropout-regularized residual
    connection derived from the conditioning vector ``c``.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True))
        self.dropout = dropout

        # Zero-init: all shifts/scales/gates start at 0, so the block is a
        # (near) identity map at initialization.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _get_bias_dropout_scale(self):
        # Select the JIT-fused residual helper matching train/eval mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c):
        """Run one block.

        Args:
            x: [B, S, dim] token states.
            rotary_cos_sin: (cos, sin) caches produced by ``Rotary``.
            c: [B, cond_dim] conditioning vector (timestep embedding).

        Returns:
            [B, S, dim] updated token states.
        """
        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        x_skip = x
        x = self.norm1(x)

        # Six modulation signals: shift/scale/gate for attention and MLP.
        (shift_msa, scale_msa, gate_msa, shift_mlp,
         scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        x = modulate_fused(x, shift_msa, scale_msa)

        qkv = einops.rearrange(
            self.attn_qkv(x),
            'b s (three h d) -> b s three h d',
            three=3,
            h=self.n_heads)
        q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
        x = regular_attention_multi_headed(q, k, v)

        # Gated residual adds around the attention and MLP sub-layers.
        x = bias_dropout_scale_fn(self.attn_out(x), None, gate_msa, x_skip, self.dropout)
        x = bias_dropout_scale_fn(
            self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, self.dropout)
        return x
214
+
215
+
216
+ def _normalize_embedding_layernorm(weight: torch.Tensor) -> torch.Tensor:
217
+ """Normalize embedding weights to unit norm per row, then scale by sqrt(dim)."""
218
+ normalized = F.normalize(weight.float(), dim=-1)
219
+ return (normalized * math.sqrt(weight.shape[-1])).to(weight.dtype)
220
+
221
+
222
class EmbeddingLayer(nn.Module):
    """Token embedding table with optional on-the-fly row normalization."""

    def __init__(self, dim, vocab_dim, use_normalized_embedding=True):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.use_normalized_embedding = use_normalized_embedding
        self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
        nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def _get_embedding(self):
        # Normalization is applied at read time, so the stored parameter
        # stays unconstrained while the effective table has fixed row norms.
        if not self.use_normalized_embedding:
            return self.embedding
        return _normalize_embedding_layernorm(self.embedding)

    def forward(self, x):
        """Embed token ids [B, L] (table lookup) or probabilities
        [B, L, V] (expected embedding under each distribution)."""
        table = self._get_embedding()
        if x.ndim == 2:
            return table[x]
        assert x.ndim == 3  # probabilities
        return torch.einsum("blv,ve->ble", x.float(), table.float()).to(x.dtype)
244
+
245
+
246
class DDiTFinalLayer(nn.Module):
    """Final DiT layer: adaLN-modulated LayerNorm + linear head.

    Projects hidden states to ``out_channels`` (vocabulary logits). Both the
    projection and the modulation are zero-initialized, so the layer emits
    zeros at initialization.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        """Map hidden states [B, S, hidden] to [B, S, out_channels],
        conditioned on ``c`` [B, cond_dim]."""
        x = self.norm_final(x)
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        x = modulate_fused(x, shift, scale)
        x = self.linear(x)
        return x
263
+
264
+
265
class GumbelProposal(nn.Module):
    """Learnable Gumbel proposal over gamma (log-SNR) values.

    ``loc`` and ``scale`` parameterize a Gumbel distribution; samples are
    produced by pushing uniform variates through its inverse CDF and
    clamping to the [gamma_min, gamma_max] range implied by ``cutoff``.
    ``entropy`` is a learnable scalar carried alongside (not used by the
    sampling path here).
    """

    def __init__(self, loc: float = 4.723, scale: float = 0.852,
                 cutoff: float = 1e-5, entropy: float = 7.02):
        super().__init__()
        self.loc = nn.Parameter(torch.tensor(loc))
        self.scale = nn.Parameter(torch.tensor(scale))
        self.cutoff = cutoff
        self.entropy = nn.Parameter(torch.tensor(entropy))

    def _get_distribution(self) -> torch.distributions.Gumbel:
        # Rebuilt per call so the current parameter values are used.
        return torch.distributions.Gumbel(self.loc, self.scale)

    @property
    def gamma_min(self) -> float:
        # Gumbel inverse CDF evaluated at q = cutoff.
        return float(self.loc - math.log(-math.log(self.cutoff)) * self.scale)

    @property
    def gamma_max(self) -> float:
        # Upper tail bound (approx. inverse CDF at q = 1 - cutoff).
        return float(self.loc - math.log(self.cutoff) * self.scale)

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        """Map uniform samples ``q`` in (0, 1) to clamped gamma values."""
        dist = self._get_distribution()
        return dist.icdf(q).clamp(min=self.gamma_min, max=self.gamma_max)

    def log_pdf(self, gamma: torch.Tensor) -> torch.Tensor:
        """Log density of the proposal evaluated at ``gamma``."""
        return self._get_distribution().log_prob(gamma)
295
+
296
+
297
class LangFlowBackbone(nn.Module):
    """DiT backbone for LangFlow: embedding, timestep conditioning, a stack
    of adaLN transformer blocks, and a logit head."""

    def __init__(self, config: LangFlowConfig):
        super().__init__()
        self.config = config
        dim = config.hidden_size
        cond_dim = config.cond_dim

        self.vocab_embed = EmbeddingLayer(
            dim, config.vocab_size,
            use_normalized_embedding=config.use_normalized_embedding)
        # Embeds the scalar gamma (log-SNR) conditioning value.
        self.sigma_map = TimestepEmbedder(cond_dim)
        self.rotary_emb = Rotary(dim // config.n_heads)

        self.blocks = nn.ModuleList([
            DDiTBlock(dim=dim, n_heads=config.n_heads, cond_dim=cond_dim, dropout=config.dropout)
            for _ in range(config.n_blocks)
        ])

        self.output_layer = DDiTFinalLayer(
            hidden_size=dim, out_channels=config.vocab_size, cond_dim=cond_dim)

        # Self-conditioning projection — zero-initialized so mixing in the
        # previous prediction is a no-op at initialization.
        if config.self_conditioning:
            self.self_cond_proj = nn.Linear(dim * 2, dim, bias=False)
            nn.init.zeros_(self.self_cond_proj.weight)

    def forward(self, x_embed, sigma, x_self_cond=None, output_hidden_states=False):
        """Forward pass from embeddings.

        Args:
            x_embed: [B, L, D] - Input embeddings (possibly noisy)
            sigma: [B] - Gamma values (log-SNR)
            x_self_cond: [B, L, D] - Self-conditioning embeddings (optional)
            output_hidden_states: Whether to return all hidden states

        Returns:
            logits: [B, L, vocab_size]
            hidden_states: List of hidden states if output_hidden_states=True
        """
        all_hidden_states = []
        x = x_embed

        if output_hidden_states:
            all_hidden_states.append(x)

        # Self-conditioning: mix in the previous denoising prediction
        # (zeros when the caller supplies none, e.g. on the first step).
        if self.config.self_conditioning:
            if x_self_cond is None:
                x_self_cond = torch.zeros_like(x)
            x = x + self.self_cond_proj(torch.cat([x, x_self_cond], dim=-1))

        t_cond = F.silu(self.sigma_map(sigma))
        rotary_cos_sin = self.rotary_emb(x)

        # Transformer trunk runs under bf16 autocast; the custom LayerNorm
        # and rotary helpers upcast to fp32 internally where needed.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c=t_cond)
                if output_hidden_states:
                    all_hidden_states.append(x)
            x = self.output_layer(x, c=t_cond)

        return x, all_hidden_states
361
+
362
+
363
class LangFlow(transformers.PreTrainedModel):
    """HuggingFace-compatible LangFlow model.

    LangFlow is a continuous diffusion language model that operates in embedding space.
    It uses a DiT (Diffusion Transformer) backbone with:
    - Self-conditioning: uses previous predictions as additional input
    - Bias (preconditioning): skip connection for improved generation
    - Normalized embeddings: layernorm on embedding vectors
    - Learnable Gumbel proposal for gamma (log-SNR) sampling
    """
    config_class = LangFlowConfig
    base_model_prefix = "langflow"

    def __init__(self, config: LangFlowConfig):
        super().__init__(config)
        self.config = config
        self.backbone = LangFlowBackbone(config)
        self.proposal = GumbelProposal(
            loc=config.gumbel_loc,
            scale=config.gumbel_scale,
            cutoff=config.gumbel_cutoff,
            entropy=config.gumbel_entropy)

    def _get_embedding_matrix(self) -> torch.Tensor:
        """Get the embedding matrix for bias skip connection."""
        return self.backbone.vocab_embed._get_embedding()

    def _embed_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Embed tokens or probabilities to continuous embeddings."""
        return self.backbone.vocab_embed(x)

    def _forward_diffusion(self, x_embed: torch.Tensor,
                           gamma: torch.Tensor) -> torch.Tensor:
        """Add noise to embeddings (forward diffusion process).

        Uses the variance-preserving parameterization
        alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma), so larger
        gamma means more noise.
        """
        gamma = gamma.float()
        alpha = torch.sigmoid(-gamma).sqrt()[:, None, None]
        sigma = torch.sigmoid(gamma).sqrt()[:, None, None]
        noise = torch.randn_like(x_embed)
        return (x_embed * alpha + noise * sigma).to(x_embed.dtype)

    def _euler_edm_step(self, z: torch.Tensor, x_pred: torch.Tensor,
                        t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Single Euler step for EDM sampling.

        Moves the latent from log-SNR level ``t`` to ``s`` given the model's
        clean-data prediction ``x_pred``. Computed in float64 for numerical
        stability, then cast back to the latent dtype.
        """
        t_ = t.double()
        s_ = s.double()
        cur = z.double() * ((F.softplus(t_) - F.softplus(s_)) / 2).exp()
        end = torch.sigmoid(-s_).sqrt() * x_pred.double()
        z = end.lerp(cur, ((s_ - t_) / 2).exp()).to(z.dtype)
        return z

    def forward(
        self,
        input_ids: typing.Optional[torch.LongTensor] = None,
        noisy_embeds: typing.Optional[torch.FloatTensor] = None,
        timesteps: typing.Optional[torch.FloatTensor] = None,
        x_self_cond: typing.Optional[torch.FloatTensor] = None,
        output_hidden_states: typing.Optional[bool] = None,
        return_dict: typing.Optional[bool] = None,
    ) -> typing.Union[torch.Tensor, typing.Tuple, transformers.modeling_outputs.MaskedLMOutput]:
        """Forward pass for LangFlow.

        Args:
            input_ids: [B, L] - Token IDs (will be embedded and noised if timesteps provided)
            noisy_embeds: [B, L, D] - Pre-noised embeddings (alternative to input_ids)
            timesteps: [B] - Gamma values (log-SNR) for conditioning
            x_self_cond: [B, L, D] - Self-conditioning embeddings
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return MaskedLMOutput

        Returns:
            logits or MaskedLMOutput
        """
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get embeddings: pre-noised latents take priority over token ids.
        if noisy_embeds is not None:
            z = noisy_embeds
        elif input_ids is not None:
            x_embed = self._embed_tokens(input_ids)
            if timesteps is not None:
                z = self._forward_diffusion(x_embed, timesteps)
            else:
                z = x_embed
        else:
            raise ValueError("Either input_ids or noisy_embeds must be provided")

        if timesteps is None:
            # Use minimum gamma for clean input
            timesteps = torch.full((z.shape[0],), self.proposal.gamma_min, device=z.device)

        # Process sigma: per-token gamma [B, L] is reduced to [B] by averaging.
        sigma = timesteps
        if sigma.ndim == 2:
            sigma = sigma.mean(-1)

        # Get model output
        logits, all_hidden_states = self.backbone(
            z, sigma, x_self_cond=x_self_cond, output_hidden_states=output_hidden_states)

        # Add bias (preconditioning) skip connection: similarity of the noisy
        # latent to each embedding row, down-weighted as gamma (noise) grows.
        if self.config.use_bias:
            c_skip = ((F.softplus(-sigma) - sigma) / 2).exp()
            embedding = self._get_embedding_matrix()
            skip_logits = torch.matmul(z.float(), embedding.t().float())
            logits = logits + c_skip[:, None, None] * skip_logits.to(logits.dtype)

        if return_dict:
            return transformers.modeling_outputs.MaskedLMOutput(
                logits=logits,
                hidden_states=all_hidden_states if output_hidden_states else None,
                loss=None)
        elif output_hidden_states:
            return logits, all_hidden_states
        else:
            return logits

    @torch.no_grad()
    def generate_samples(
        self,
        num_samples: int = 1,
        seq_length: typing.Optional[int] = None,
        num_steps: int = 128,
        device: typing.Optional[torch.device] = None,
    ) -> torch.LongTensor:
        """Generate samples using Euler-EDM solver.

        Args:
            num_samples: Number of samples to generate
            seq_length: Sequence length (defaults to config.model_length)
            num_steps: Number of denoising steps
            device: Device to generate on

        Returns:
            samples: [num_samples, seq_length] - Generated token IDs
        """
        if seq_length is None:
            seq_length = self.config.model_length
        if device is None:
            device = next(self.parameters()).device

        embed_dim = self.config.hidden_size
        eps = 1e-5

        # Initialize with Gaussian noise
        z = torch.randn(num_samples, seq_length, embed_dim, device=device)

        # Create gamma schedule from t=1-eps to t=eps by pushing the time
        # grid through the learned Gumbel proposal's inverse CDF.
        t = torch.linspace(1.0 - eps, eps, num_steps, device=device)
        gamma = self.proposal(t)

        # Self-conditioning state (None on the first step -> zeros inside
        # the backbone).
        x_self_cond = None

        # Euler-EDM sampling loop
        for i in range(len(gamma) - 1):
            gamma_t = gamma[i]
            gamma_s = gamma[i + 1]

            # Get model prediction
            gamma_expanded = gamma_t.unsqueeze(0).expand(num_samples)
            logits = self.forward(
                noisy_embeds=z,
                timesteps=gamma_expanded,
                x_self_cond=x_self_cond,
                return_dict=False)

            # Convert logits to embedding prediction: expected embedding
            # under the softmax distribution over the vocabulary.
            probs = F.softmax(logits.float(), dim=-1)
            x_pred = self._embed_tokens(probs)

            # Update self-conditioning
            if self.config.self_conditioning:
                x_self_cond = x_pred

            # Euler step
            z = self._euler_edm_step(z, x_pred, gamma_t, gamma_s)

        # Final step: get logits and take argmax
        gamma_final = gamma[-1]
        gamma_expanded = gamma_final.unsqueeze(0).expand(num_samples)
        logits = self.forward(
            noisy_embeds=z,
            timesteps=gamma_expanded,
            x_self_cond=x_self_cond,
            return_dict=False)
        samples = logits.argmax(dim=-1)

        return samples
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dcc2573297d9f0ad91cb08d3469d141817d7bd3b11a2a7f332ff0f6a58e02a1
3
+ size 683235528