##############################################################################################################################################
#||||- - - |8.19.2025| - - - || LIQUID STATE SPACE || - - - |1990two| - - -|||| #
##############################################################################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import List, Dict, Tuple, Optional
from scipy import linalg
from scipy.signal import cont2discrete
SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8
#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#
def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    # Replace NaN with 0 and +/-inf with the matching clamp bound, then clip.
    # (The original mapped -inf to max_val before clamping, which silently
    # turned -inf into +max_val.)
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)
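# Quick check (illustrative addition, not part of the original API): NaN maps
# to 0, +/-inf map to the clamp bounds, and finite values pass through clipped.
def _demo_make_safe():
    x = torch.tensor([float('nan'), float('inf'), -float('inf'), 42.0])
    return make_safe(x)  # tensor([0.0, 1e6, -1e6, 42.0])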
def discrete_to_continuous_time(A_discrete, dt=1.0):
    # Matrix logarithm maps a discrete transition matrix back to its
    # continuous-time generator: A_c = logm(A_d) / dt. logm can return a
    # complex result for real inputs, so keep only the real part.
    try:
        A_continuous = np.real(linalg.logm(A_discrete.detach().cpu().numpy())) / dt
        return torch.tensor(A_continuous, dtype=torch.float32, device=A_discrete.device)
    except Exception:
        # Fallback if logm fails (e.g., singular or defective A_discrete).
        return torch.eye(A_discrete.shape[0], device=A_discrete.device) * 0.01
def continuous_to_discrete_time(A_continuous, B_continuous, dt=1.0):
    # Zero-order-hold discretization via scipy's cont2discrete. Supports a
    # batched A of shape (batch, n, n) with a shared B of shape (n, m).
    try:
        A_np = A_continuous.detach().cpu().numpy()
        B_np = B_continuous.detach().cpu().numpy()
        if A_np.ndim == 3:
            A_list, B_list = [], []
            for i in range(A_np.shape[0]):
                Ad, Bd, _, _, _ = cont2discrete((A_np[i], B_np, np.eye(A_np.shape[-1]), 0), dt)
                A_list.append(Ad)
                B_list.append(Bd)
            A_discrete = torch.tensor(np.stack(A_list), dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(np.stack(B_list), dtype=torch.float32, device=B_continuous.device)
        else:
            A_discrete, B_discrete, _, _, _ = cont2discrete((A_np, B_np, np.eye(A_np.shape[0]), 0), dt)
            A_discrete = torch.tensor(A_discrete, dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(B_discrete, dtype=torch.float32, device=B_continuous.device)
        return A_discrete, B_discrete
    except Exception:
        # Fallback: first-order (forward Euler) approximation
        # A_d = I + A*dt, B_d = B*dt.
        n = A_continuous.shape[-1]
        eye = torch.eye(n, device=A_continuous.device, dtype=A_continuous.dtype)
        if A_continuous.dim() == 3:
            eye = eye.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
            B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device) \
                                 .unsqueeze(0).expand(A_continuous.size(0), -1, -1)
        else:
            B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device)
        A_discrete = eye + A_continuous * dt
        B_discrete = B_disc * dt
        return A_discrete, B_discrete
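# Sanity-check sketch (illustrative addition; _check_discretization is not part
# of the original API): for a stable A and a small step, the zero-order-hold
# result should be close to the forward-Euler fallback A_d = I + A*dt.
def _check_discretization(state_dim=4, input_dim=2, dt=1e-3):
    A = -torch.eye(state_dim) + 0.01 * torch.randn(state_dim, state_dim)
    B = torch.randn(state_dim, input_dim) * 0.1
    A_d, _ = continuous_to_discrete_time(A, B, dt)
    A_euler = torch.eye(state_dim) + A * dt
    return torch.allclose(A_d, A_euler, atol=1e-4)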
###########################################################################################################################################
#############################################- - - LIQUID TIME CONSTANT CONTROLLER - - -###############################################
class LiquidTimeConstantController(nn.Module):
def __init__(self, state_dim, input_dim, init_tau=1.0):
super().__init__()
self.state_dim = state_dim
self.input_dim = input_dim
self.log_tau = nn.Parameter(torch.ones(state_dim) * math.log(init_tau))
self.tau_adaptation = nn.Sequential(
nn.Linear(state_dim + input_dim, state_dim * 2),
nn.LayerNorm(state_dim * 2),
nn.Tanh(),
nn.Linear(state_dim * 2, state_dim),
nn.Tanh() # Output in [-1, 1] for modulation
)
self.adaptation_rate = nn.Parameter(torch.tensor(0.1))
    def get_time_constants(self, state, input_signal):
        # Base time constants are learned in log space (guaranteeing
        # positivity) and clamped to a stable range before modulation.
        base_tau = torch.exp(self.log_tau)
        base_tau = torch.clamp(base_tau, 0.01, 10.0)
combined_input = torch.cat([state, input_signal], dim=-1)
tau_modulation = self.tau_adaptation(combined_input)
adaptation_rate = torch.clamp(self.adaptation_rate, 0.001, 1.0)
modulated_tau = base_tau * (1.0 + adaptation_rate * tau_modulation)
return torch.clamp(modulated_tau, 0.01, 10.0)
    def get_effective_dt(self, tau, target_dt=0.1):
        # Cap the integration step at one tenth of the fastest time constant
        # so the discretized dynamics stay numerically stable.
        min_tau_val = torch.min(tau).item()
        effective_dt = max(0.001, min(float(target_dt), min_tau_val * 0.1))
        return effective_dt
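# Usage sketch (illustrative addition; _demo_time_constants is hypothetical):
# the controller maps the current state and input to per-dimension time
# constants, then derives a step size from the fastest one.
def _demo_time_constants(batch=2, state_dim=8, input_dim=4):
    ctrl = LiquidTimeConstantController(state_dim, input_dim)
    tau = ctrl.get_time_constants(torch.zeros(batch, state_dim),
                                  torch.randn(batch, input_dim))
    return tau, ctrl.get_effective_dt(tau, target_dt=0.1)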
###########################################################################################################################################
################################################- - - LIQUID SSM CORE - - -############################################################
class LiquidSSMCore(nn.Module):
def __init__(self, state_dim, input_dim, output_dim, dt=0.1, init_method='hippo'):
super().__init__()
self.state_dim = state_dim
self.input_dim = input_dim
self.output_dim = output_dim
self.dt = dt
if init_method == 'hippo':
self.A_continuous = nn.Parameter(self._init_hippo_matrix(state_dim))
else:
self.A_continuous = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)
self.B_continuous = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
self.C = nn.Parameter(torch.randn(output_dim, state_dim) * 0.1)
self.D = nn.Parameter(torch.zeros(output_dim, input_dim))
self.time_controller = LiquidTimeConstantController(state_dim, input_dim, init_tau=1.0)
self.output_scale = nn.Parameter(torch.ones(output_dim))
self.output_bias = nn.Parameter(torch.zeros(output_dim))
self.state_normalizer = nn.LayerNorm(state_dim)
self.register_buffer('continuous_state', torch.zeros(1, state_dim))
    def _init_hippo_matrix(self, N):
        # HiPPO-LegS initialization (lower triangular):
        #   A[i, j] = -sqrt(2i+1) * sqrt(2j+1)  for i > j
        #   A[i, i] = -(i + 1)
        # then rescaled so the spectral radius stays below 1.
        A = torch.zeros(N, N)
        for i in range(N):
            for j in range(N):
                if i > j:
                    A[i, j] = -math.sqrt(2 * i + 1) * math.sqrt(2 * j + 1)
                elif i == j:
                    A[i, j] = -(i + 1)
        A = A * 0.1
        with torch.no_grad():
            spectral_radius = torch.linalg.eigvals(A).abs().max()
            if spectral_radius > 0:
                A = A / spectral_radius * 0.9
        return A
def reset_state(self, batch_size=1):
device = self.A_continuous.device
self.continuous_state = torch.zeros(batch_size, self.state_dim, device=device)
def liquid_state_evolution(self, input_signal, num_steps=10):
batch_size = input_signal.shape[0]
if self.continuous_state.shape[0] != batch_size:
self.reset_state(batch_size)
        tau = self.time_controller.get_time_constants(self.continuous_state, input_signal)
        effective_dt = self.time_controller.get_effective_dt(tau, self.dt)
        # Subtracting diag(1/tau) adds a per-dimension leak, so each state
        # dimension decays at its own input-dependent rate (the "liquid" part).
        tau_matrix = torch.diag_embed(1.0 / tau)
        liquid_A = self.A_continuous - tau_matrix
        liquid_A = make_safe(liquid_A, min_val=-10.0, max_val=10.0)
        # Discretize once at the substep resolution; the num_steps updates
        # below integrate the dynamics over the full effective_dt. (The
        # original discretized twice and discarded the first result.)
        step_dt = float(effective_dt) / num_steps
        A_discrete, B_discrete = continuous_to_discrete_time(
            liquid_A, self.B_continuous, step_dt
        )
        current_state = self.continuous_state
        # The input is held constant across substeps (zero-order hold), so its
        # contribution through B is computed once, outside the substep loop.
        if A_discrete.dim() == 3:
A_T = A_discrete.transpose(1, 2)
B_T = B_discrete.transpose(1, 2)
input_update = torch.bmm(input_signal.unsqueeze(1), B_T).squeeze(1)
for _ in range(num_steps):
state_update = torch.bmm(current_state.unsqueeze(1), A_T).squeeze(1)
current_state = state_update + input_update
current_state = make_safe(current_state)
else:
A_T = A_discrete.T
B_T = B_discrete.T
input_update = input_signal @ B_T
for _ in range(num_steps):
current_state = current_state @ A_T + input_update
current_state = make_safe(current_state)
current_state = make_safe(current_state)
self.continuous_state = current_state
return current_state, tau, effective_dt
def compute_output(self, state, input_signal):
normalized_state = self.state_normalizer(state)
state_output = torch.matmul(normalized_state, self.C.T)
direct_output = torch.matmul(input_signal, self.D.T)
raw_output = state_output + direct_output
output = self.output_scale * raw_output + self.output_bias
return make_safe(output)
def forward(self, input_signal, return_diagnostics=False):
evolved_state, tau, effective_dt = self.liquid_state_evolution(input_signal)
output = self.compute_output(evolved_state, input_signal)
result = {
'output': output,
'state': evolved_state
}
if return_diagnostics:
result.update({
'time_constants': tau,
'effective_dt': effective_dt,
'state_norm': torch.norm(evolved_state, dim=-1),
'adaptation_rate': self.time_controller.adaptation_rate
})
return result
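# Usage sketch (illustrative addition; _demo_core_step is hypothetical): one
# liquid-SSM update. State is reset implicitly when the batch size changes;
# diagnostics expose the adapted time constants and the step size used.
def _demo_core_step(batch=2, state_dim=8, input_dim=4, output_dim=4):
    core = LiquidSSMCore(state_dim, input_dim, output_dim)
    result = core(torch.randn(batch, input_dim), return_diagnostics=True)
    return result['output'].shape, result['effective_dt']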
###########################################################################################################################################
############################################- - - LIQUID SSM SEQUENCE LAYER - - -######################################################
class LiquidSSMSequenceLayer(nn.Module):
def __init__(self, input_dim, state_dim, output_dim, seq_len=None):
super().__init__()
self.input_dim = input_dim
self.state_dim = state_dim
self.output_dim = output_dim
self.seq_len = seq_len
self.liquid_ssm = LiquidSSMCore(state_dim, state_dim, output_dim)
self.input_projection = nn.Sequential(
nn.Linear(input_dim, state_dim),
nn.LayerNorm(state_dim),
nn.GELU()
)
self.output_projection = nn.Sequential(
nn.Linear(output_dim, output_dim * 2),
nn.LayerNorm(output_dim * 2),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(output_dim * 2, output_dim)
)
self.residual_weight = nn.Parameter(torch.tensor(0.1))
self.sequence_adapter = nn.Sequential(
nn.Linear(state_dim, state_dim),
nn.Tanh(),
nn.Linear(state_dim, 1),
nn.Sigmoid()
)
def forward(self, sequence, return_diagnostics=False):
batch_size, seq_len, input_dim = sequence.shape
self.liquid_ssm.reset_state(batch_size)
outputs = []
diagnostics = [] if return_diagnostics else None
for t in range(seq_len):
current_input = sequence[:, t, :]
projected_input = self.input_projection(current_input)
ssm_result = self.liquid_ssm(projected_input, return_diagnostics=return_diagnostics)
adaptation_factor = self.sequence_adapter(ssm_result['state'])
adapted_output = ssm_result['output'] * adaptation_factor
final_output = self.output_projection(adapted_output)
            # Residual connection only applies when output_dim == input_dim.
            if final_output.shape == current_input.shape:
                residual_strength = torch.clamp(self.residual_weight, 0.0, 1.0)
                final_output = final_output + residual_strength * current_input
outputs.append(final_output)
if return_diagnostics:
diagnostics.append({
'timestep': t,
'adaptation_factor': adaptation_factor.mean().item(),
**ssm_result
})
output_sequence = torch.stack(outputs, dim=1)
result = {'output': output_sequence}
if return_diagnostics:
result['diagnostics'] = diagnostics
return result
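# Usage sketch (illustrative addition; _demo_sequence_layer is hypothetical):
# the layer scans a (batch, seq_len, input_dim) tensor one step at a time,
# carrying the liquid state across timesteps within the sequence.
def _demo_sequence_layer(batch=2, seq_len=16, d=32):
    layer = LiquidSSMSequenceLayer(input_dim=d, state_dim=64, output_dim=d)
    return layer(torch.randn(batch, seq_len, d))['output'].shape  # (2, 16, 32)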
###########################################################################################################################################
###########################################- - - LIQUID SSM LANGUAGE MODEL - - -#######################################################
class LiquidSSMLanguageModel(nn.Module):
def __init__(self, vocab_size, d_model=512, state_dim=256, num_layers=6, max_seq_len=2048):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.state_dim = state_dim
self.num_layers = num_layers
self.max_seq_len = max_seq_len
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_seq_len, d_model)
self.liquid_layers = nn.ModuleList([
LiquidSSMSequenceLayer(d_model, state_dim, d_model)
for _ in range(num_layers)
])
self.layer_norms = nn.ModuleList([
nn.LayerNorm(d_model) for _ in range(num_layers)
])
self.output_norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size)
self.global_adaptation = nn.Sequential(
nn.Linear(d_model, d_model // 4),
nn.GELU(),
nn.Linear(d_model // 4, 1),
nn.Sigmoid()
)
self._init_weights()
def _init_weights(self):
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, input_ids, labels=None, return_diagnostics=False):
batch_size, seq_len = input_ids.shape
device = input_ids.device
if seq_len > self.max_seq_len:
input_ids = input_ids[:, :self.max_seq_len]
seq_len = self.max_seq_len
if labels is not None:
labels = labels[:, :self.max_seq_len]
input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)
token_emb = self.token_embedding(input_ids)
pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
pos_emb = self.position_embedding(pos_ids)
x = token_emb + pos_emb
x = make_safe(x)
layer_diagnostics = [] if return_diagnostics else None
for layer_idx, (liquid_layer, layer_norm) in enumerate(zip(self.liquid_layers, self.layer_norms)):
residual = x
x = layer_norm(x)
layer_result = liquid_layer(x, return_diagnostics=return_diagnostics)
x = layer_result['output']
adaptation = self.global_adaptation(x.mean(dim=1, keepdim=True))
x = x * adaptation
x = residual + x
x = make_safe(x)
if return_diagnostics:
layer_diagnostics.append({
'layer': layer_idx,
'adaptation': adaptation.mean().item(),
**layer_result
})
x = self.output_norm(x)
logits = self.lm_head(x)
logits = make_safe(logits, min_val=-50, max_val=50)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, self.vocab_size),
shift_labels.view(-1),
ignore_index=-100
)
result = {
'logits': logits,
'loss': loss
}
if return_diagnostics:
result['layer_diagnostics'] = layer_diagnostics
return result
@torch.no_grad()
def generate(self, input_ids, max_length=100, temperature=1.0, top_p=0.95, return_diagnostics=False):
self.eval()
generated = input_ids.clone()
all_diagnostics = [] if return_diagnostics else None
for step in range(max_length - input_ids.shape[1]):
if generated.shape[1] > self.max_seq_len:
break
outputs = self(generated, return_diagnostics=return_diagnostics)
logits = outputs['logits']
if return_diagnostics:
all_diagnostics.append(outputs.get('layer_diagnostics', []))
next_token_logits = logits[:, -1, :] / max(temperature, EPS)
next_token_logits = make_safe(next_token_logits, min_val=-50, max_val=50)
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right so at least the highest-probability token survives.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
for b in range(next_token_logits.size(0)):
indices_to_remove = sorted_indices[b][sorted_indices_to_remove[b]]
next_token_logits[b, indices_to_remove] = -float('inf')
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token = torch.clamp(next_token, 0, self.vocab_size - 1)
generated = torch.cat([generated, next_token], dim=1)
            # Stop when every sequence in the batch emits the EOS token (id 2).
            # (.item() would crash for batch sizes greater than 1.)
            if (next_token == 2).all():
                break
result = {'generated_ids': generated}
if return_diagnostics:
result['diagnostics'] = all_diagnostics
return result
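#||||- - - smoke test (illustrative addition, not part of the original file) - - -||||#
if __name__ == "__main__":
    # Tiny configuration so the forward pass and sampling loop run quickly on CPU.
    model = LiquidSSMLanguageModel(vocab_size=100, d_model=64, state_dim=32,
                                   num_layers=2, max_seq_len=128)
    ids = torch.randint(0, 100, (2, 16))
    out = model(ids, labels=ids)
    print("logits:", tuple(out['logits'].shape), "loss:", float(out['loss']))
    gen = model.generate(ids[:1], max_length=24, temperature=0.8, top_p=0.9)
    print("generated:", tuple(gen['generated_ids'].shape))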