nealchen committed (verified)
Commit ff393be · 1 parent: 8d8169e

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,61 @@
---
datasets:
- dvruette/lm1b
papers:
- arxiv: 2604.11748
language:
- en
library_name: transformers
license: apache-2.0
metrics:
- perplexity
pipeline_tag: text-generation
---

# LangFlow

LangFlow is a continuous diffusion language model that operates in embedding space. Unlike discrete diffusion models (MDLM, SEDD, DUO), LangFlow performs diffusion directly on continuous token embeddings, enabling smoother denoising dynamics.

For more details, please see our paper: [LangFlow: Continuous Diffusion Rivals Discrete in Language Modeling](https://arxiv.org/abs/2604.11748).

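Concretely, the forward (noising) process implemented in `model.py` mixes clean token embeddings with Gaussian noise according to a gamma (log-SNR) value: the signal scale is `sqrt(sigmoid(-gamma))` and the noise scale is `sqrt(sigmoid(gamma))`. A minimal sketch mirroring `LangFlow._forward_diffusion`:

```python
import torch

def forward_diffusion(x_embed: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    # x_embed: [B, L, D] clean token embeddings; gamma: [B] gamma (log-SNR) values
    alpha = torch.sigmoid(-gamma).sqrt()[:, None, None]  # signal scale
    sigma = torch.sigmoid(gamma).sqrt()[:, None, None]   # noise scale
    noise = torch.randn_like(x_embed)
    return x_embed * alpha + noise * sigma
```
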
## Using LangFlow

To generate text with the pre-trained model, use the following snippet:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForMaskedLM.from_pretrained('Continuous-Rivals-Discrete/langflow-owt', trust_remote_code=True)

# Generate samples
samples = model.generate_samples(num_samples=5, num_steps=128)
texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
for text in texts:
    print(text)
```
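
All arguments to `generate_samples` are optional; the values below are illustrative and show the knobs exposed in `model.py`:

```python
samples = model.generate_samples(
    num_samples=8,     # how many sequences to draw
    seq_length=128,    # defaults to config.model_length
    num_steps=256,     # more denoising steps: slower, usually higher quality
    device=None,       # defaults to the model's device
)
texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
```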

## Model Details

- **Architecture**: DiT (Diffusion Transformer) backbone with adaptive layer normalization
- **Context Length**: 128 tokens
- **Parameters**: ~130M parameters (similar to GPT-2 small)
- **Training**: 1M steps on the LM1B corpus
- **Tokenizer**: bert-base-uncased (30,522-token vocabulary)

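These hyperparameters are recorded in `config.json` and can be inspected after loading, for example:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained('Continuous-Rivals-Discrete/langflow-owt', trust_remote_code=True)
print(config.hidden_size, config.n_blocks, config.n_heads, config.model_length)
# 768 12 12 128
```
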
## Citation

```bibtex
@article{chen2026langflow,
  title={LangFlow: Continuous Diffusion Rivals Discrete in Language Modeling},
  author={Chen, Yuxin and Liang, Chumeng and Sui, Hangke and Guo, Ruihan and Cheng, Chaoran and You, Jiaxuan and Liu, Ge},
  journal={arXiv preprint arXiv:2604.11748},
  year={2026}
}
```

## Model Card Contact

Chumeng Liang (chumengl@illinois.edu)

__init__.py ADDED
@@ -0,0 +1,6 @@
"""HuggingFace release package for LangFlow."""

from .config import LangFlowConfig
from .model import LangFlow

__all__ = ["LangFlowConfig", "LangFlow"]
__pycache__/config.cpython-313.pyc ADDED
Binary file (2.47 kB)

__pycache__/model.cpython-313.pyc ADDED
Binary file (34 kB)

config.json ADDED
@@ -0,0 +1,29 @@
{
  "_name_or_path": "nealchen/langflow-lm1b",
  "architectures": [
    "LangFlow"
  ],
  "auto_map": {
    "AutoConfig": "config.LangFlowConfig",
    "AutoModelForMaskedLM": "model.LangFlow"
  },
  "model_type": "LangFlow",
  "vocab_size": 30522,
  "hidden_size": 768,
  "cond_dim": 128,
  "n_blocks": 12,
  "n_heads": 12,
  "dropout": 0.1,
  "model_length": 128,
  "use_normalized_embedding": true,
  "embedding_norm_method": "layernorm",
  "self_conditioning": true,
  "use_bias": true,
  "gumbel_loc": 4.723,
  "gumbel_scale": 0.852,
  "gumbel_cutoff": 1e-5,
  "gumbel_entropy": 7.02,
  "return_dict": true,
  "torch_dtype": "float32",
  "transformers_version": "4.38.2"
}
config.py ADDED
@@ -0,0 +1,63 @@
"""HuggingFace configuration class for LangFlow."""

import transformers


class LangFlowConfig(transformers.PretrainedConfig):
    """HuggingFace configuration class for LangFlow.

    LangFlow is a continuous diffusion language model that operates in embedding space.
    It uses a DiT (Diffusion Transformer) backbone with adaptive layer normalization.

    Key features:
    - Continuous diffusion in embedding space
    - Self-conditioning: uses previous predictions as additional input
    - Bias (preconditioning): skip connection for improved training
    - Normalized embeddings: layernorm on embedding vectors
    - Learnable Gumbel proposal for gamma (log-SNR) sampling
    """
    model_type = "LangFlow"

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        cond_dim: int = 128,
        n_blocks: int = 12,
        n_heads: int = 12,
        dropout: float = 0.1,
        model_length: int = 128,
        # Embedding normalization
        use_normalized_embedding: bool = True,
        embedding_norm_method: str = "layernorm",
        # Self-conditioning
        self_conditioning: bool = True,
        # Bias (preconditioning) - always enabled for inference
        use_bias: bool = True,
        # Gumbel proposal parameters (learnable)
        gumbel_loc: float = 4.723,
        gumbel_scale: float = 0.852,
        gumbel_cutoff: float = 1e-5,
        gumbel_entropy: float = 7.02,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.cond_dim = cond_dim
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.dropout = dropout
        self.model_length = model_length
        # Embedding normalization
        self.use_normalized_embedding = use_normalized_embedding
        self.embedding_norm_method = embedding_norm_method
        # Self-conditioning
        self.self_conditioning = self_conditioning
        # Bias (preconditioning)
        self.use_bias = use_bias
        # Gumbel proposal parameters
        self.gumbel_loc = gumbel_loc
        self.gumbel_scale = gumbel_scale
        self.gumbel_cutoff = gumbel_cutoff
        self.gumbel_entropy = gumbel_entropy
model.py ADDED
@@ -0,0 +1,560 @@
"""HuggingFace model implementation for LangFlow.

LangFlow is a continuous diffusion language model that operates in embedding space.
"""

import math
import typing

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from .config import LangFlowConfig


# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def bias_dropout_add_scale(
        x: torch.Tensor,
        bias: typing.Optional[torch.Tensor],
        scale: torch.Tensor,
        residual: typing.Optional[torch.Tensor],
        prob: float,
        training: bool) -> torch.Tensor:
    if bias is not None:
        out = scale * F.dropout(x + bias, p=prob, training=training)
    else:
        out = scale * F.dropout(x, p=prob, training=training)

    if residual is not None:
        out = residual + out
    return out


@torch.jit.script
def bias_dropout_add_scale_fused_train(
        x: torch.Tensor,
        bias: typing.Optional[torch.Tensor],
        scale: torch.Tensor,
        residual: typing.Optional[torch.Tensor],
        prob: float) -> torch.Tensor:
    return bias_dropout_add_scale(x, bias, scale, residual, prob, True)


@torch.jit.script
def bias_dropout_add_scale_fused_inference(
        x: torch.Tensor,
        bias: typing.Optional[torch.Tensor],
        scale: torch.Tensor,
        residual: typing.Optional[torch.Tensor],
        prob: float) -> torch.Tensor:
    return bias_dropout_add_scale(x, bias, scale, residual, prob, False)


@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
    return x * (1 + scale) + shift


class Rotary(nn.Module):
    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.cos_cached[:, :, 2, :, :].fill_(1.)
            self.sin_cached[:, :, 2, :, :].fill_(0.)
        return self.cos_cached, self.sin_cached


def _apply_rotary_emb(x, cos, sin):
    # x: [batch, seqlen, nheads, headdim]
    # cos, sin: [seqlen, headdim//2]
    ro_dim = cos.shape[-1] * 2
    # Expand to [1, seqlen, 1, ro_dim] for broadcasting
    cos = torch.cat([cos, cos], dim=-1)[None, :, None, :]
    sin = torch.cat([sin, sin], dim=-1)[None, :, None, :]
    x_rot = x[..., :ro_dim]
    x1, x2 = x_rot.chunk(2, dim=-1)
    x_rotated = torch.cat([-x2, x1], dim=-1)
    return torch.cat([x_rot * cos + x_rotated * sin, x[..., ro_dim:]], dim=-1)


def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
    with torch.autocast(device_type='cuda', enabled=False):
        cos, sin = rotary_cos_sin
        cos = cos.to(qkv.dtype)
        sin = sin.to(qkv.dtype)
        cos = cos[0, :, 0, 0, :cos.shape[-1]//2]
        sin = sin[0, :, 0, 0, :sin.shape[-1]//2]
        q, k, v = qkv.chunk(3, dim=2)
        q = _apply_rotary_emb(q.squeeze(dim=2), cos, sin)
        k = _apply_rotary_emb(k.squeeze(dim=2), cos, sin)
        v = v.squeeze(dim=2)
    return q, k, v


def regular_attention_multi_headed(q, k, v):
    attention_output = F.scaled_dot_product_attention(
        query=q.transpose(1, 2),
        key=k.transpose(1, 2),
        value=v.transpose(1, 2),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False)
    attention_output = attention_output.transpose(1, 2)
    return einops.rearrange(attention_output, 'b s h d -> b s (h d)')


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x):
        with torch.autocast(device_type='cuda', enabled=False):
            x = F.layer_norm(x.float(), [self.dim])
        return x * self.weight[None, None, :]


class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True))
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class DDiTBlock(nn.Module):
    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True))
        self.dropout = dropout

        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _get_bias_dropout_scale(self):
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def forward(self, x, rotary_cos_sin, c):
        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        x_skip = x
        x = self.norm1(x)

        (shift_msa, scale_msa, gate_msa, shift_mlp,
         scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        x = modulate_fused(x, shift_msa, scale_msa)

        qkv = einops.rearrange(
            self.attn_qkv(x),
            'b s (three h d) -> b s three h d',
            three=3,
            h=self.n_heads)
        q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
        x = regular_attention_multi_headed(q, k, v)

        x = bias_dropout_scale_fn(self.attn_out(x), None, gate_msa, x_skip, self.dropout)
        x = bias_dropout_scale_fn(
            self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, self.dropout)
        return x


def _normalize_embedding_layernorm(weight: torch.Tensor) -> torch.Tensor:
    """Normalize embedding weights to unit norm per row, then scale by sqrt(dim)."""
    normalized = F.normalize(weight.float(), dim=-1)
    return (normalized * math.sqrt(weight.shape[-1])).to(weight.dtype)


class EmbeddingLayer(nn.Module):
    """Embedding layer with optional layernorm normalization."""

    def __init__(self, dim, vocab_dim, use_normalized_embedding=True):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.use_normalized_embedding = use_normalized_embedding
        self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
        nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def _get_embedding(self):
        if self.use_normalized_embedding:
            return _normalize_embedding_layernorm(self.embedding)
        return self.embedding

    def forward(self, x):
        embedding = self._get_embedding()
        if x.ndim == 2:
            return embedding[x]
        assert x.ndim == 3  # probabilities
        return torch.einsum("blv,ve->ble", x.float(), embedding.float()).to(x.dtype)


class DDiTFinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels, cond_dim):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        x = self.norm_final(x)
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        x = modulate_fused(x, shift, scale)
        x = self.linear(x)
        return x


class GumbelProposal(nn.Module):
    """Learnable Gumbel distribution proposal for sampling gamma (log-SNR)."""

    def __init__(self, loc: float = 4.723, scale: float = 0.852,
                 cutoff: float = 1e-5, entropy: float = 7.02):
        super().__init__()
        self.loc = nn.Parameter(torch.tensor(loc))
        self.scale = nn.Parameter(torch.tensor(scale))
        self.cutoff = cutoff
        self.entropy = nn.Parameter(torch.tensor(entropy))

    def _get_distribution(self) -> torch.distributions.Gumbel:
        return torch.distributions.Gumbel(self.loc, self.scale)

    @property
    def gamma_min(self) -> float:
        return float(self.loc - math.log(-math.log(self.cutoff)) * self.scale)

    @property
    def gamma_max(self) -> float:
        return float(self.loc - math.log(self.cutoff) * self.scale)

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        """Convert uniform samples to gamma values via inverse CDF."""
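        # Gumbel inverse CDF: icdf(q) = loc - scale * log(-log(q)); the clamp below
        # keeps gamma within [gamma_min, gamma_max], both of which are set by `cutoff`.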
        gamma = self._get_distribution().icdf(q)
        return gamma.clamp(min=self.gamma_min, max=self.gamma_max)

    def log_pdf(self, gamma: torch.Tensor) -> torch.Tensor:
        """Compute log probability density at gamma."""
        return self._get_distribution().log_prob(gamma)


class LangFlowBackbone(nn.Module):
    """DiT backbone for LangFlow."""

    def __init__(self, config: LangFlowConfig):
        super().__init__()
        self.config = config
        dim = config.hidden_size
        cond_dim = config.cond_dim

        self.vocab_embed = EmbeddingLayer(
            dim, config.vocab_size,
            use_normalized_embedding=config.use_normalized_embedding)
        self.sigma_map = TimestepEmbedder(cond_dim)
        self.rotary_emb = Rotary(dim // config.n_heads)

        self.blocks = nn.ModuleList([
            DDiTBlock(dim=dim, n_heads=config.n_heads, cond_dim=cond_dim, dropout=config.dropout)
            for _ in range(config.n_blocks)
        ])

        self.output_layer = DDiTFinalLayer(
            hidden_size=dim, out_channels=config.vocab_size, cond_dim=cond_dim)

        # Self-conditioning projection
        if config.self_conditioning:
            self.self_cond_proj = nn.Linear(dim * 2, dim, bias=False)
            nn.init.zeros_(self.self_cond_proj.weight)

    def forward(self, x_embed, sigma, x_self_cond=None, output_hidden_states=False):
        """Forward pass from embeddings.

        Args:
            x_embed: [B, L, D] - Input embeddings (possibly noisy)
            sigma: [B] - Gamma values (log-SNR)
            x_self_cond: [B, L, D] - Self-conditioning embeddings (optional)
            output_hidden_states: Whether to return all hidden states

        Returns:
            logits: [B, L, vocab_size]
            hidden_states: List of hidden states if output_hidden_states=True
        """
        all_hidden_states = []
        x = x_embed

        if output_hidden_states:
            all_hidden_states.append(x)

        # Self-conditioning
        if self.config.self_conditioning:
            if x_self_cond is None:
                x_self_cond = torch.zeros_like(x)
            x = x + self.self_cond_proj(torch.cat([x, x_self_cond], dim=-1))

        t_cond = F.silu(self.sigma_map(sigma))
        rotary_cos_sin = self.rotary_emb(x)

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, c=t_cond)
                if output_hidden_states:
                    all_hidden_states.append(x)
            x = self.output_layer(x, c=t_cond)

        return x, all_hidden_states


class LangFlow(transformers.PreTrainedModel):
    """HuggingFace-compatible LangFlow model.

    LangFlow is a continuous diffusion language model that operates in embedding space.
    It uses a DiT (Diffusion Transformer) backbone with:
    - Self-conditioning: uses previous predictions as additional input
    - Bias (preconditioning): skip connection for improved generation
    - Normalized embeddings: layernorm on embedding vectors
    - Learnable Gumbel proposal for gamma (log-SNR) sampling
    """
    config_class = LangFlowConfig
    base_model_prefix = "langflow"

    def __init__(self, config: LangFlowConfig):
        super().__init__(config)
        self.config = config
        self.backbone = LangFlowBackbone(config)
        self.proposal = GumbelProposal(
            loc=config.gumbel_loc,
            scale=config.gumbel_scale,
            cutoff=config.gumbel_cutoff,
            entropy=config.gumbel_entropy)

    def _get_embedding_matrix(self) -> torch.Tensor:
        """Get the embedding matrix for bias skip connection."""
        return self.backbone.vocab_embed._get_embedding()

    def _embed_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Embed tokens or probabilities to continuous embeddings."""
        return self.backbone.vocab_embed(x)

    def _forward_diffusion(self, x_embed: torch.Tensor,
                           gamma: torch.Tensor) -> torch.Tensor:
        """Add noise to embeddings (forward diffusion process)."""
        gamma = gamma.float()
        alpha = torch.sigmoid(-gamma).sqrt()[:, None, None]
        sigma = torch.sigmoid(gamma).sqrt()[:, None, None]
        noise = torch.randn_like(x_embed)
        return (x_embed * alpha + noise * sigma).to(x_embed.dtype)

    def _euler_edm_step(self, z: torch.Tensor, x_pred: torch.Tensor,
                        t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Single Euler step for EDM sampling."""
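        # With alpha(g)^2 = sigmoid(-g) and softplus(g) = -log(sigmoid(-g)),
        # exp((softplus(t) - softplus(s)) / 2) equals alpha(s) / alpha(t), so `cur`
        # rescales z to the signal level of the less noisy step s; `end` is the fully
        # denoised state alpha(s) * x_pred; the lerp weight exp((s - t) / 2)
        # interpolates between them (a deterministic, DDIM-style update).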
        t_ = t.double()
        s_ = s.double()
        cur = z.double() * ((F.softplus(t_) - F.softplus(s_)) / 2).exp()
        end = torch.sigmoid(-s_).sqrt() * x_pred.double()
        z = end.lerp(cur, ((s_ - t_) / 2).exp()).to(z.dtype)
        return z

    def forward(
        self,
        input_ids: typing.Optional[torch.LongTensor] = None,
        noisy_embeds: typing.Optional[torch.FloatTensor] = None,
        timesteps: typing.Optional[torch.FloatTensor] = None,
        x_self_cond: typing.Optional[torch.FloatTensor] = None,
        output_hidden_states: typing.Optional[bool] = None,
        return_dict: typing.Optional[bool] = None,
    ) -> typing.Union[torch.Tensor, typing.Tuple, transformers.modeling_outputs.MaskedLMOutput]:
        """Forward pass for LangFlow.

        Args:
            input_ids: [B, L] - Token IDs (will be embedded and noised if timesteps provided)
            noisy_embeds: [B, L, D] - Pre-noised embeddings (alternative to input_ids)
            timesteps: [B] - Gamma values (log-SNR) for conditioning
            x_self_cond: [B, L, D] - Self-conditioning embeddings
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return MaskedLMOutput

        Returns:
            logits or MaskedLMOutput
        """
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get embeddings
        if noisy_embeds is not None:
            z = noisy_embeds
        elif input_ids is not None:
            x_embed = self._embed_tokens(input_ids)
            if timesteps is not None:
                z = self._forward_diffusion(x_embed, timesteps)
            else:
                z = x_embed
        else:
            raise ValueError("Either input_ids or noisy_embeds must be provided")

        if timesteps is None:
            # Use minimum gamma for clean input
            timesteps = torch.full((z.shape[0],), self.proposal.gamma_min, device=z.device)

        # Process sigma
        sigma = timesteps
        if sigma.ndim == 2:
            sigma = sigma.mean(-1)

        # Get model output
        logits, all_hidden_states = self.backbone(
            z, sigma, x_self_cond=x_self_cond, output_hidden_states=output_hidden_states)

        # Add bias (preconditioning) skip connection
        if self.config.use_bias:
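            # `sigma` here holds gamma (log-SNR). c_skip = alpha / sigma_noise^2 with
            # alpha^2 = sigmoid(-gamma) and sigma_noise^2 = sigmoid(gamma), so the skip
            # term (noisy embeddings dotted with the vocabulary embeddings) is weighted
            # strongly at low noise and vanishes at high noise.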
            c_skip = ((F.softplus(-sigma) - sigma) / 2).exp()
            embedding = self._get_embedding_matrix()
            skip_logits = torch.matmul(z.float(), embedding.t().float())
            logits = logits + c_skip[:, None, None] * skip_logits.to(logits.dtype)

        if return_dict:
            return transformers.modeling_outputs.MaskedLMOutput(
                logits=logits,
                hidden_states=all_hidden_states if output_hidden_states else None,
                loss=None)
        elif output_hidden_states:
            return logits, all_hidden_states
        else:
            return logits

    @torch.no_grad()
    def generate_samples(
        self,
        num_samples: int = 1,
        seq_length: typing.Optional[int] = None,
        num_steps: int = 128,
        device: typing.Optional[torch.device] = None,
    ) -> torch.LongTensor:
        """Generate samples using Euler-EDM solver.

        Args:
            num_samples: Number of samples to generate
            seq_length: Sequence length (defaults to config.model_length)
            num_steps: Number of denoising steps
            device: Device to generate on

        Returns:
            samples: [num_samples, seq_length] - Generated token IDs
        """
        if seq_length is None:
            seq_length = self.config.model_length
        if device is None:
            device = next(self.parameters()).device

        embed_dim = self.config.hidden_size
        eps = 1e-5

        # Initialize with Gaussian noise
        z = torch.randn(num_samples, seq_length, embed_dim, device=device)

        # Create gamma schedule from t=1-eps to t=eps
        t = torch.linspace(1.0 - eps, eps, num_steps, device=device)
        gamma = self.proposal(t)

        # Self-conditioning state
        x_self_cond = None

        # Euler-EDM sampling loop
        for i in range(len(gamma) - 1):
            gamma_t = gamma[i]
            gamma_s = gamma[i + 1]

            # Get model prediction
            gamma_expanded = gamma_t.unsqueeze(0).expand(num_samples)
            logits = self.forward(
                noisy_embeds=z,
                timesteps=gamma_expanded,
                x_self_cond=x_self_cond,
                return_dict=False)

            # Convert logits to embedding prediction
            probs = F.softmax(logits.float(), dim=-1)
            x_pred = self._embed_tokens(probs)

            # Update self-conditioning
            if self.config.self_conditioning:
                x_self_cond = x_pred

            # Euler step
            z = self._euler_edm_step(z, x_pred, gamma_t, gamma_s)

        # Final step: get logits and take argmax
        gamma_final = gamma[-1]
        gamma_expanded = gamma_final.unsqueeze(0).expand(num_samples)
        logits = self.forward(
            noisy_embeds=z,
            timesteps=gamma_expanded,
            x_self_cond=x_self_cond,
            return_dict=False)
        samples = logits.argmax(dim=-1)

        return samples
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b111b5141379bde1bc8531278a2d144bc753cf3b100ed32c8a917f2ed0c025d2
size 561904748