import math
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import linalg
from scipy.signal import cont2discrete

SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8


def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Replace NaN/Inf entries and clamp the tensor to a finite range."""
    zero = torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
    maxv = torch.tensor(max_val, device=tensor.device, dtype=tensor.dtype)
    minv = torch.tensor(min_val, device=tensor.device, dtype=tensor.dtype)
    tensor = torch.where(torch.isnan(tensor), zero, tensor)
    tensor = torch.where(tensor == float('inf'), maxv, tensor)
    tensor = torch.where(tensor == float('-inf'), minv, tensor)
    return torch.clamp(tensor, min_val, max_val)
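

# Usage sketch (an illustrative addition, not part of the original API surface): make_safe
# zeroes NaNs, maps +/-inf to the configured bounds, and clips everything into
# [min_val, max_val]. The helper name and test values below are arbitrary.
def _demo_make_safe():
    x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.0])
    y = make_safe(x, min_val=-10.0, max_val=10.0)
    assert torch.equal(y, torch.tensor([0.0, 10.0, -10.0, 3.0]))
    return y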


def discrete_to_continuous_time(A_discrete, dt=1.0):
    """Recover a continuous-time transition matrix from a discrete one via the matrix logarithm."""
    try:
        A_np = A_discrete.detach().cpu().numpy()
        # logm may return a complex matrix with a negligible imaginary part; keep the real part.
        A_continuous = np.real(linalg.logm(A_np)) / dt
        return torch.tensor(A_continuous, dtype=torch.float32, device=A_discrete.device)
    except Exception:
        # Fall back to a small, stable placeholder if the matrix logarithm fails.
        return torch.eye(A_discrete.shape[0], device=A_discrete.device) * 0.01


def continuous_to_discrete_time(A_continuous, B_continuous, dt=1.0):
    """Zero-order-hold discretization of (A, B); falls back to a forward-Euler step on failure."""
    try:
        A_np = A_continuous.detach().cpu().numpy()
        B_np = B_continuous.detach().cpu().numpy()

        if A_np.ndim == 3:
            # Batched A: discretize each system in the batch with a shared B.
            A_list, B_list = [], []
            for i in range(A_np.shape[0]):
                Ad, Bd, _, _, _ = cont2discrete((A_np[i], B_np, np.eye(A_np.shape[-1]), 0), dt)
                A_list.append(Ad)
                B_list.append(Bd)
            A_discrete = torch.tensor(np.stack(A_list), dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(np.stack(B_list), dtype=torch.float32, device=B_continuous.device)
        else:
            A_discrete, B_discrete, _, _, _ = cont2discrete((A_np, B_np, np.eye(A_np.shape[0]), 0), dt)
            A_discrete = torch.tensor(A_discrete, dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(B_discrete, dtype=torch.float32, device=B_continuous.device)

        return A_discrete, B_discrete
    except Exception:
        # Forward-Euler fallback: A_d = I + A*dt, B_d = B*dt.
        n = A_continuous.shape[-1]
        eye = torch.eye(n, device=A_continuous.device, dtype=A_continuous.dtype)
        if A_continuous.dim() == 3:
            eye = eye.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
            B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device) \
                .unsqueeze(0).expand(A_continuous.size(0), -1, -1)
        else:
            B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device)
        A_discrete = eye + A_continuous * dt
        B_discrete = B_disc * dt
        return A_discrete, B_discrete
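

# Illustrative check (an addition, not original code): for a scalar system the zero-order-hold
# discretization above should match the closed form A_d = exp(a*dt). The values and the
# helper name are arbitrary assumptions for demonstration.
def _demo_zoh_scalar_check():
    a = torch.tensor([[-0.5]])
    b = torch.tensor([[1.0]])
    dt = 0.1
    a_d, b_d = continuous_to_discrete_time(a, b, dt)
    assert abs(a_d.item() - math.exp(-0.5 * dt)) < 1e-4
    return a_d, b_d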


class LiquidTimeConstantController(nn.Module):
    """Produces input- and state-dependent time constants for the liquid SSM."""

    def __init__(self, state_dim, input_dim, init_tau=1.0):
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim

        # Learnable base time constants, stored in log space so they stay positive.
        self.log_tau = nn.Parameter(torch.ones(state_dim) * math.log(init_tau))

        # Small network that modulates the base time constants from the current state and input.
        self.tau_adaptation = nn.Sequential(
            nn.Linear(state_dim + input_dim, state_dim * 2),
            nn.LayerNorm(state_dim * 2),
            nn.Tanh(),
            nn.Linear(state_dim * 2, state_dim),
            nn.Tanh()
        )

        self.adaptation_rate = nn.Parameter(torch.tensor(0.1))

    def get_time_constants(self, state, input_signal):
        base_tau = torch.exp(self.log_tau)
        base_tau = torch.clamp(base_tau, 0.01, 10.0)

        combined_input = torch.cat([state, input_signal], dim=-1)
        tau_modulation = self.tau_adaptation(combined_input)

        adaptation_rate = torch.clamp(self.adaptation_rate, 0.001, 1.0)
        modulated_tau = base_tau * (1.0 + adaptation_rate * tau_modulation)

        return torch.clamp(modulated_tau, 0.01, 10.0)

    def get_effective_dt(self, tau, target_dt=0.1):
        # Keep the integration step well below the fastest time constant.
        min_tau_val = torch.min(tau).item()
        effective_dt = max(0.001, min(float(target_dt), min_tau_val * 0.1))
        return effective_dt
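

# Usage sketch (dimensions and the helper name are example assumptions, not from the original
# code): the controller maps a (batch, state_dim) state and a (batch, input_dim) input to
# per-unit time constants clamped to [0.01, 10.0], and derives a conservative step size.
def _demo_time_constants():
    ctrl = LiquidTimeConstantController(state_dim=8, input_dim=4)
    state = torch.randn(2, 8)
    inp = torch.randn(2, 4)
    tau = ctrl.get_time_constants(state, inp)
    dt = ctrl.get_effective_dt(tau, target_dt=0.1)
    assert tau.shape == (2, 8) and 0.001 <= dt <= 0.1
    return tau, dt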


class LiquidSSMCore(nn.Module):
    """Continuous-time state-space core with liquid (input-dependent) time constants."""

    def __init__(self, state_dim, input_dim, output_dim, dt=0.1, init_method='hippo'):
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dt = dt

        if init_method == 'hippo':
            self.A_continuous = nn.Parameter(self._init_hippo_matrix(state_dim))
        else:
            self.A_continuous = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)

        self.B_continuous = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
        self.C = nn.Parameter(torch.randn(output_dim, state_dim) * 0.1)
        self.D = nn.Parameter(torch.zeros(output_dim, input_dim))

        self.time_controller = LiquidTimeConstantController(state_dim, input_dim, init_tau=1.0)

        self.output_scale = nn.Parameter(torch.ones(output_dim))
        self.output_bias = nn.Parameter(torch.zeros(output_dim))

        self.state_normalizer = nn.LayerNorm(state_dim)

        self.register_buffer('continuous_state', torch.zeros(1, state_dim))

    def _init_hippo_matrix(self, N):
        # HiPPO-style lower-triangular initialization, rescaled to a spectral radius of 0.9.
        A = torch.zeros(N, N)
        for i in range(N):
            for j in range(N):
                if i > j:
                    A[i, j] = math.sqrt(2 * i + 1) * math.sqrt(2 * j + 1)
                elif i == j:
                    A[i, j] = -(2 * i + 1)
        A = A * 0.1
        with torch.no_grad():
            eig = torch.linalg.eigvals(A).real.abs().max()
            if eig > 0:
                A = A / eig * 0.9
        return A

    def reset_state(self, batch_size=1):
        device = self.A_continuous.device
        self.continuous_state = torch.zeros(batch_size, self.state_dim, device=device)

    def liquid_state_evolution(self, input_signal, num_steps=10):
        batch_size = input_signal.shape[0]

        if self.continuous_state.shape[0] != batch_size:
            self.reset_state(batch_size)

        # Input- and state-dependent time constants determine the integration step.
        tau = self.time_controller.get_time_constants(self.continuous_state, input_signal)
        effective_dt = self.time_controller.get_effective_dt(tau, self.dt)

        # Add a per-unit leak term -1/tau to the continuous dynamics.
        tau_matrix = torch.diag_embed(1.0 / tau)
        liquid_A = self.A_continuous - tau_matrix
        liquid_A = make_safe(liquid_A, min_val=-10.0, max_val=10.0)

        # Discretize once at the sub-step size, then integrate for num_steps sub-steps.
        step_dt = float(effective_dt) / num_steps
        A_discrete, B_discrete = continuous_to_discrete_time(
            liquid_A, self.B_continuous, step_dt
        )
        current_state = self.continuous_state

        if A_discrete.dim() == 3:
            A_T = A_discrete.transpose(1, 2)
            B_T = B_discrete.transpose(1, 2)
            input_update = torch.bmm(input_signal.unsqueeze(1), B_T).squeeze(1)
            for _ in range(num_steps):
                state_update = torch.bmm(current_state.unsqueeze(1), A_T).squeeze(1)
                current_state = state_update + input_update
                current_state = make_safe(current_state)
        else:
            A_T = A_discrete.T
            B_T = B_discrete.T
            input_update = input_signal @ B_T
            for _ in range(num_steps):
                current_state = current_state @ A_T + input_update
                current_state = make_safe(current_state)

        current_state = make_safe(current_state)
        self.continuous_state = current_state

        return current_state, tau, effective_dt

    def compute_output(self, state, input_signal):
        normalized_state = self.state_normalizer(state)

        state_output = torch.matmul(normalized_state, self.C.T)
        direct_output = torch.matmul(input_signal, self.D.T)

        raw_output = state_output + direct_output
        output = self.output_scale * raw_output + self.output_bias

        return make_safe(output)

    def forward(self, input_signal, return_diagnostics=False):
        evolved_state, tau, effective_dt = self.liquid_state_evolution(input_signal)
        output = self.compute_output(evolved_state, input_signal)

        result = {
            'output': output,
            'state': evolved_state
        }

        if return_diagnostics:
            result.update({
                'time_constants': tau,
                'effective_dt': effective_dt,
                'state_norm': torch.norm(evolved_state, dim=-1),
                'adaptation_rate': self.time_controller.adaptation_rate
            })

        return result
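

# Usage sketch (an illustrative addition; sizes and the helper name are assumptions): a single
# forward step of the liquid SSM core on a random batch, with diagnostics enabled.
def _demo_liquid_core_step():
    core = LiquidSSMCore(state_dim=16, input_dim=8, output_dim=8)
    core.reset_state(batch_size=2)
    x = torch.randn(2, 8)
    out = core(x, return_diagnostics=True)
    assert out['output'].shape == (2, 8)
    assert out['time_constants'].shape == (2, 16)
    return out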


class LiquidSSMSequenceLayer(nn.Module):
    """Applies the liquid SSM core step by step over a sequence."""

    def __init__(self, input_dim, state_dim, output_dim, seq_len=None):
        super().__init__()
        self.input_dim = input_dim
        self.state_dim = state_dim
        self.output_dim = output_dim
        self.seq_len = seq_len

        self.liquid_ssm = LiquidSSMCore(state_dim, state_dim, output_dim)

        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, state_dim),
            nn.LayerNorm(state_dim),
            nn.GELU()
        )

        self.output_projection = nn.Sequential(
            nn.Linear(output_dim, output_dim * 2),
            nn.LayerNorm(output_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim * 2, output_dim)
        )

        self.residual_weight = nn.Parameter(torch.tensor(0.1))

        # Per-timestep gate computed from the SSM state.
        self.sequence_adapter = nn.Sequential(
            nn.Linear(state_dim, state_dim),
            nn.Tanh(),
            nn.Linear(state_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, sequence, return_diagnostics=False):
        batch_size, seq_len, input_dim = sequence.shape

        self.liquid_ssm.reset_state(batch_size)

        outputs = []
        diagnostics = [] if return_diagnostics else None

        for t in range(seq_len):
            current_input = sequence[:, t, :]

            projected_input = self.input_projection(current_input)
            ssm_result = self.liquid_ssm(projected_input, return_diagnostics=return_diagnostics)

            adaptation_factor = self.sequence_adapter(ssm_result['state'])
            adapted_output = ssm_result['output'] * adaptation_factor

            final_output = self.output_projection(adapted_output)

            # Gated residual connection when input and output dimensions match.
            if final_output.shape == current_input.shape:
                residual_strength = torch.clamp(self.residual_weight, 0.0, 1.0)
                final_output = final_output + residual_strength * current_input

            outputs.append(final_output)

            if return_diagnostics:
                diagnostics.append({
                    'timestep': t,
                    'adaptation_factor': adaptation_factor.mean().item(),
                    **ssm_result
                })

        output_sequence = torch.stack(outputs, dim=1)

        result = {'output': output_sequence}

        if return_diagnostics:
            result['diagnostics'] = diagnostics

        return result
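

# Usage sketch (example shapes and the helper name are assumptions): a short random sequence
# run through one liquid sequence layer; the output keeps the (batch, seq_len, output_dim) layout.
def _demo_sequence_layer():
    layer = LiquidSSMSequenceLayer(input_dim=32, state_dim=16, output_dim=32)
    layer.eval()  # disable dropout for a deterministic sanity check
    seq = torch.randn(2, 5, 32)
    result = layer(seq)
    assert result['output'].shape == (2, 5, 32)
    return result['output']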


class LiquidSSMLanguageModel(nn.Module):
    """Language model built from stacked liquid SSM sequence layers."""

    def __init__(self, vocab_size, d_model=512, state_dim=256, num_layers=6, max_seq_len=2048):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.state_dim = state_dim
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        self.liquid_layers = nn.ModuleList([
            LiquidSSMSequenceLayer(d_model, state_dim, d_model)
            for _ in range(num_layers)
        ])

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(num_layers)
        ])

        self.output_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

        # Sequence-level gate applied to each layer's output before the residual connection.
        self.global_adaptation = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid()
        )

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None, return_diagnostics=False):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Truncate sequences that exceed the positional embedding table.
        if seq_len > self.max_seq_len:
            input_ids = input_ids[:, :self.max_seq_len]
            seq_len = self.max_seq_len
            if labels is not None:
                labels = labels[:, :self.max_seq_len]

        input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)

        token_emb = self.token_embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.position_embedding(pos_ids)

        x = token_emb + pos_emb
        x = make_safe(x)

        layer_diagnostics = [] if return_diagnostics else None

        for layer_idx, (liquid_layer, layer_norm) in enumerate(zip(self.liquid_layers, self.layer_norms)):
            residual = x

            # Pre-norm block: normalize, run the liquid layer, gate, then add the residual.
            x = layer_norm(x)
            layer_result = liquid_layer(x, return_diagnostics=return_diagnostics)
            x = layer_result['output']

            adaptation = self.global_adaptation(x.mean(dim=1, keepdim=True))
            x = x * adaptation

            x = residual + x
            x = make_safe(x)

            if return_diagnostics:
                layer_diagnostics.append({
                    'layer': layer_idx,
                    'adaptation': adaptation.mean().item(),
                    **layer_result
                })

        x = self.output_norm(x)
        logits = self.lm_head(x)
        logits = make_safe(logits, min_val=-50, max_val=50)

        loss = None
        if labels is not None:
            # Standard next-token prediction: shift logits and labels by one position.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )

        result = {
            'logits': logits,
            'loss': loss
        }

        if return_diagnostics:
            result['layer_diagnostics'] = layer_diagnostics

        return result

    @torch.no_grad()
    def generate(self, input_ids, max_length=100, temperature=1.0, top_p=0.95, return_diagnostics=False):
        self.eval()
        generated = input_ids.clone()
        all_diagnostics = [] if return_diagnostics else None

        for step in range(max_length - input_ids.shape[1]):
            if generated.shape[1] > self.max_seq_len:
                break

            outputs = self(generated, return_diagnostics=return_diagnostics)
            logits = outputs['logits']

            if return_diagnostics:
                all_diagnostics.append(outputs.get('layer_diagnostics', []))

            next_token_logits = logits[:, -1, :] / max(temperature, EPS)
            next_token_logits = make_safe(next_token_logits, min_val=-50, max_val=50)

            # Nucleus (top-p) filtering.
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False

            for b in range(next_token_logits.size(0)):
                indices_to_remove = sorted_indices[b][sorted_indices_to_remove[b]]
                next_token_logits[b, indices_to_remove] = -float('inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token = torch.clamp(next_token, 0, self.vocab_size - 1)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop once every sequence in the batch has emitted the assumed EOS id (2).
            if (next_token == 2).all():
                break

        result = {'generated_ids': generated}
        if return_diagnostics:
            result['diagnostics'] = all_diagnostics

        return result
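

# Smoke-test sketch (an addition to the original file; the hyperparameters are arbitrary small
# values): build a tiny model, run a forward pass with labels, then sample a short
# continuation. Token id 2 is treated as EOS inside `generate`, an assumption carried over
# from that method.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = LiquidSSMLanguageModel(vocab_size=100, d_model=32, state_dim=16,
                                   num_layers=2, max_seq_len=64)
    ids = torch.randint(0, 100, (1, 8))
    out = model(ids, labels=ids)
    print("loss:", float(out['loss']))
    gen = model.generate(ids, max_length=16, temperature=1.0, top_p=0.9)
    print("generated shape:", tuple(gen['generated_ids'].shape))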