| """ |
| Published baseline models for NeurIPS 2026 benchmark experiments. |
| |
| Contains faithful implementations of 6 published models: |
| 1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016) - Exp1/Exp3 |
| 2. InceptionTime (Fawaz et al., DMKD 2020) - Exp1/Exp3 |
| 3. MS-TCN++ (Li et al., TPAMI 2020) - Exp2 |
| 4. DiffAct (Liu et al., ICCV 2023) - Exp2 |
| 5. UnderPressure (Mourot et al., SCA/CGF 2022) - Exp3/Exp4a |
| 6. emg2pose (Meta, NeurIPS 2024 D&B) - Exp4b |
| """ |
|
|
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
# ---------------------------------------------------------------------------
# 1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016)
# ---------------------------------------------------------------------------
|
|
class DeepConvLSTMBackbone(nn.Module):
    """DeepConvLSTM backbone for sequence-level classification (Exp1).

    Input: (B, T, C), optional mask
    Output: (B, output_dim)
    """

    def __init__(self, input_dim, hidden_dim=128, num_conv_layers=4,
                 conv_filters=64, conv_kernel=5, num_lstm_layers=2):
        super().__init__()
        conv_layers = []
        in_ch = input_dim
        for i in range(num_conv_layers):
            out_ch = conv_filters
            conv_layers.append(nn.Sequential(
                nn.Conv1d(in_ch, out_ch, conv_kernel, padding=conv_kernel // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
                # Heavier dropout on the last conv layer
                nn.Dropout(0.1 if i < num_conv_layers - 1 else 0.2),
            ))
            in_ch = out_ch
        self.convs = nn.ModuleList(conv_layers)

        self.lstm = nn.LSTM(
            conv_filters, hidden_dim, num_layers=num_lstm_layers,
            batch_first=True, bidirectional=False,
            dropout=0.2 if num_lstm_layers > 1 else 0,
        )
        self.output_dim = hidden_dim

    def forward(self, x, mask=None):
        # (B, T, C) -> (B, C, T) for Conv1d, then back for the LSTM
        x = x.permute(0, 2, 1)
        for conv in self.convs:
            x = conv(x)
        x = x.permute(0, 2, 1)

        out, (h_n, _) = self.lstm(x)
        # Sequence summary: last hidden state of the top LSTM layer
        feat = h_n[-1]
        return feat
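

# A minimal smoke test for the backbone, assuming a hypothetical batch of
# 30-channel windows of 200 frames (sizes are illustrative, not benchmark
# settings):
def _demo_deepconvlstm_backbone():
    model = DeepConvLSTMBackbone(input_dim=30)
    feats = model(torch.randn(8, 200, 30))
    assert feats.shape == (8, 128)  # output_dim == hidden_dim == 128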
|
|
|
|
class DeepConvLSTMContact(nn.Module):
    """DeepConvLSTM for frame-level contact detection (Exp3).

    Input: (B, T, C)
    Output: (B, T, 2)
    """

    def __init__(self, input_dim, hidden_dim=64, num_conv_layers=4,
                 conv_filters=64, conv_kernel=5):
        super().__init__()
        conv_layers = []
        in_ch = input_dim
        for i in range(num_conv_layers):
            conv_layers.append(nn.Sequential(
                nn.Conv1d(in_ch, conv_filters, conv_kernel, padding=conv_kernel // 2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                nn.Dropout(0.1),
            ))
            in_ch = conv_filters
        self.convs = nn.ModuleList(conv_layers)
        self.lstm = nn.LSTM(conv_filters, hidden_dim, num_layers=2,
                            batch_first=True, bidirectional=True, dropout=0.2)
        self.head = nn.Linear(hidden_dim * 2, 2)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)
        for conv in self.convs:
            x = conv(x)
        x = x.permute(0, 2, 1)
        out, _ = self.lstm(x)
        return self.head(out)  # per-frame logits, (B, T, 2)
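

# Shape sketch for the contact variant (hypothetical sizes): frame-level
# logits for the two contact channels.
def _demo_deepconvlstm_contact():
    model = DeepConvLSTMContact(input_dim=30)
    logits = model(torch.randn(8, 200, 30))
    assert logits.shape == (8, 200, 2)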
|
|
|
|
# ---------------------------------------------------------------------------
# 2. InceptionTime (Fawaz et al., DMKD 2020)
# ---------------------------------------------------------------------------
|
|
class InceptionModule(nn.Module):
    """Single Inception module for time series."""

    def __init__(self, in_channels, n_filters=32, kernel_sizes=(9, 19, 39),
                 bottleneck_channels=32):
        super().__init__()
        # 1x1 bottleneck to reduce channels before the wide kernels
        self.bottleneck = nn.Conv1d(in_channels, bottleneck_channels, 1, bias=False)

        # Parallel convolutions with odd, length-preserving kernels
        self.convs = nn.ModuleList()
        for ks in kernel_sizes:
            self.convs.append(
                nn.Conv1d(bottleneck_channels, n_filters, ks,
                          padding=(ks - 1) // 2, bias=False)
            )

        # Max-pooling branch operating on the raw input
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool1d(3, stride=1, padding=1),
            nn.Conv1d(in_channels, n_filters, 1, bias=False),
        )

        self.bn = nn.BatchNorm1d(n_filters * (len(kernel_sizes) + 1))
        self.relu = nn.ReLU()

    def forward(self, x):
        # Concatenate the conv branches and the pooling branch
        x_bottleneck = self.bottleneck(x)
        conv_outputs = [conv(x_bottleneck) for conv in self.convs]
        conv_outputs.append(self.maxpool_conv(x))
        out = torch.cat(conv_outputs, dim=1)
        return self.relu(self.bn(out))
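

# Channel-count sketch (hypothetical sizes): three conv branches plus the
# pooling branch each emit n_filters channels, so the module outputs
# n_filters * 4 channels at the input's temporal length.
def _demo_inception_module():
    mod = InceptionModule(in_channels=30, n_filters=32)
    out = mod(torch.randn(8, 30, 200))
    assert out.shape == (8, 128, 200)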
|
|
|
|
class InceptionBlock(nn.Module):
    """Stack of Inception modules with a residual connection."""

    def __init__(self, in_channels, n_filters=32, depth=3):
        super().__init__()
        n_out = n_filters * 4
        modules = []
        for i in range(depth):
            inc = in_channels if i == 0 else n_out
            modules.append(InceptionModule(inc, n_filters))
        self.modules_list = nn.ModuleList(modules)

        # The residual is always added; a 1x1 projection is only needed
        # when the input and output channel counts differ.
        self.use_projection = (in_channels != n_out)
        if self.use_projection:
            self.residual = nn.Sequential(
                nn.Conv1d(in_channels, n_out, 1, bias=False),
                nn.BatchNorm1d(n_out),
            )
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        for mod in self.modules_list:
            x = mod(x)
        if self.use_projection:
            residual = self.residual(residual)
        return self.relu(x + residual)
|
|
|
|
class InceptionTimeBackbone(nn.Module):
    """InceptionTime backbone for sequence-level classification (Exp1).

    Input: (B, T, C), optional mask
    Output: (B, output_dim), where output_dim = n_filters * 4
    """

    def __init__(self, input_dim, hidden_dim=128, n_filters=32, num_blocks=2, depth=3):
        super().__init__()
        blocks = []
        in_ch = input_dim
        for i in range(num_blocks):
            blocks.append(InceptionBlock(in_ch, n_filters, depth))
            in_ch = n_filters * 4
        self.blocks = nn.ModuleList(blocks)
        self.output_dim = n_filters * 4

    def forward(self, x, mask=None):
        x = x.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)
        for block in self.blocks:
            x = block(x)
        # Global average pooling over time, restricted to valid frames
        # when a padding mask is provided
        if mask is not None:
            x = (x * mask.unsqueeze(1).float()).sum(2) / mask.sum(1, keepdim=True).float().clamp(min=1)
        else:
            x = x.mean(2)
        return x
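

# Masked-pooling sketch (hypothetical sizes): padded frames are excluded
# from the temporal average via the boolean mask.
def _demo_inceptiontime_backbone():
    model = InceptionTimeBackbone(input_dim=30)
    x = torch.randn(8, 200, 30)
    mask = torch.ones(8, 200, dtype=torch.bool)
    mask[:, 150:] = False  # last 50 frames are padding
    feats = model(x, mask)
    assert feats.shape == (8, model.output_dim)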
|
|
|
|
class InceptionTimeContact(nn.Module):
    """InceptionTime for frame-level contact detection (Exp3).

    Input: (B, T, C)
    Output: (B, T, 2)
    """

    def __init__(self, input_dim, hidden_dim=64, n_filters=32, num_blocks=2, depth=3):
        super().__init__()
        blocks = []
        in_ch = input_dim
        for i in range(num_blocks):
            blocks.append(InceptionBlock(in_ch, n_filters, depth))
            in_ch = n_filters * 4
        self.blocks = nn.ModuleList(blocks)
        # 1x1 conv as a per-frame classification head
        self.head = nn.Conv1d(n_filters * 4, 2, 1)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)
        out = self.head(x)
        return out.permute(0, 2, 1)  # (B, T, 2)
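

# Frame-level shape sketch (hypothetical sizes):
def _demo_inceptiontime_contact():
    model = InceptionTimeContact(input_dim=30)
    logits = model(torch.randn(8, 200, 30))
    assert logits.shape == (8, 200, 2)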
|
|
|
|
# ---------------------------------------------------------------------------
# 3. MS-TCN++ (Li et al., TPAMI 2020)
# ---------------------------------------------------------------------------
|
|
class DualDilatedResBlock(nn.Module):
    """Dual dilated residual block (MS-TCN++ key contribution).

    Uses two parallel dilated convolutions with different dilation rates
    to capture both short-range and long-range temporal patterns.
    """

    def __init__(self, channels, dilation1, dilation2):
        super().__init__()
        self.conv1_dilated = nn.Conv1d(
            channels, channels, 3,
            padding=dilation1, dilation=dilation1
        )
        self.conv2_dilated = nn.Conv1d(
            channels, channels, 3,
            padding=dilation2, dilation=dilation2
        )
        self.conv_fusion = nn.Conv1d(channels, channels, 1)
        self.bn = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        residual = x
        out1 = F.relu(self.conv1_dilated(x))
        out2 = F.relu(self.conv2_dilated(x))
        out = out1 + out2
        out = self.dropout(F.relu(self.bn(self.conv_fusion(out))))
        return out + residual
|
|
|
|
class MSTCNPPStage(nn.Module):
    """Single stage of MS-TCN++ with dual dilated layers."""

    def __init__(self, in_channels, hidden_channels, num_classes, num_layers=10):
        super().__init__()
        self.input_conv = nn.Conv1d(in_channels, hidden_channels, 1)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Dual dilation scheme from the paper: 2^i paired with
            # 2^(L-1-i), combining a growing and a shrinking receptive
            # field in each layer.
            dilation1 = 2 ** i
            dilation2 = 2 ** (num_layers - 1 - i)
            self.layers.append(DualDilatedResBlock(hidden_channels, dilation1, dilation2))
        self.output_conv = nn.Conv1d(hidden_channels, num_classes, 1)

    def forward(self, x):
        x = self.input_conv(x)
        for layer in self.layers:
            x = layer(x)
        return self.output_conv(x)
|
|
|
|
class MSTCNPP(nn.Module):
    """MS-TCN++ for temporal action segmentation (Exp2).

    Input: (B, T, C)
    Output: list of (B, T, num_classes) per stage
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64, num_stages=4, num_layers=10):
        super().__init__()
        self.stages = nn.ModuleList()
        # Prediction generation stage operates on the raw features
        self.stages.append(MSTCNPPStage(input_dim, hidden_dim, num_classes, num_layers))
        # Refinement stages operate on the previous stage's predictions
        for _ in range(num_stages - 1):
            self.stages.append(MSTCNPPStage(num_classes, hidden_dim, num_classes, num_layers))

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)
        outputs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            outputs.append(x.permute(0, 2, 1))
            # All but the last stage feed softmaxed predictions forward
            if i < len(self.stages) - 1:
                x = F.softmax(x, dim=1)
        return outputs
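

# Usage sketch plus a typical multi-stage loss, assuming hypothetical sizes
# and the default MS-TCN/MS-TCN++ loss hyperparameters (lambda = 0.15,
# tau = 4; Farha & Gall 2019). This is an illustrative sketch, not a loss
# defined elsewhere in this file.
def _mstcnpp_loss_sketch(outputs, labels, lam=0.15, tau=4.0):
    """Cross-entropy plus truncated-MSE smoothing, summed over stages."""
    total = outputs[0].new_zeros(())
    for out in outputs:                   # each (B, T, num_classes)
        logits = out.permute(0, 2, 1)     # (B, C, T) for cross_entropy
        total = total + F.cross_entropy(logits, labels)
        logp = F.log_softmax(logits, dim=1)
        # Penalize large frame-to-frame jumps in log-probabilities,
        # truncated at tau^2 so true boundaries are not over-smoothed
        diff = (logp[:, :, 1:] - logp[:, :, :-1].detach()) ** 2
        total = total + lam * torch.clamp(diff, max=tau ** 2).mean()
    return total


def _demo_mstcnpp():
    model = MSTCNPP(input_dim=30, num_classes=10)
    outputs = model(torch.randn(2, 100, 30))
    assert len(outputs) == 4 and outputs[0].shape == (2, 100, 10)
    labels = torch.randint(0, 10, (2, 100))
    loss = _mstcnpp_loss_sketch(outputs, labels)
    assert torch.isfinite(loss)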
|
|
|
|
# ---------------------------------------------------------------------------
# 4. DiffAct (Liu et al., ICCV 2023)
# ---------------------------------------------------------------------------
|
|
class ConditionalLayerNorm(nn.Module):
    """Channel normalization for the denoising network.

    Implemented as GroupNorm with a single group; the diffusion-timestep
    conditioning is injected additively in DiffActBlock rather than here.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(1, channels)

    def forward(self, x):
        return self.norm(x)
|
|
|
|
class DiffActBlock(nn.Module):
    """Residual block for the DiffAct denoising network."""

    def __init__(self, channels, dilation, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, 3, padding=dilation, dilation=dilation)
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.norm1 = ConditionalLayerNorm(channels)
        self.norm2 = ConditionalLayerNorm(channels)
        self.time_proj = nn.Linear(time_emb_dim, channels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, time_emb):
        residual = x
        x = self.norm1(x)
        x = F.relu(self.conv1(x))
        # Inject the timestep embedding as a per-channel bias
        t = self.time_proj(time_emb).unsqueeze(-1)
        x = x + t
        x = self.norm2(x)
        x = self.dropout(F.relu(self.conv2(x)))
        return x + residual
|
|
|
|
class DiffActConditionEncoder(nn.Module):
    """Temporal feature encoder for conditioning the denoising network."""

    def __init__(self, input_dim, hidden_dim, num_layers=6):
        super().__init__()
        self.input_conv = nn.Conv1d(input_dim, hidden_dim, 1)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2 ** (i % 5)
            self.layers.append(nn.Sequential(
                nn.Conv1d(hidden_dim, hidden_dim, 3, padding=dilation, dilation=dilation),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
            ))

    def forward(self, x):
        x = self.input_conv(x)
        for layer in self.layers:
            x = layer(x) + x  # residual connection
        return x
|
|
|
|
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal positional embedding for the diffusion timestep."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t.unsqueeze(-1).float() * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return self.mlp(emb)
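

# Embedding sketch: maps integer timesteps to dim-sized vectors via sin/cos
# features plus an MLP (dim must be even for the sin/cos concatenation).
def _demo_time_embedding():
    emb = SinusoidalTimeEmbedding(64)
    out = emb(torch.arange(10))
    assert out.shape == (10, 64)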
|
|
|
|
class DiffAct(nn.Module):
    """DiffAct: Diffusion Action Segmentation (Exp2).

    During training: noises the (detached) initial stage predictions and
    trains the denoising network to recover clean action probabilities.
    During inference: iteratively denoises from pure noise, conditioned on
    the encoded input features.

    Input: (B, T, C)
    Output: list of (B, T, num_classes) [initial and denoised predictions]
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64,
                 num_encoder_layers=6, num_denoise_layers=6,
                 num_diffusion_steps=10):
        super().__init__()
        self.num_classes = num_classes
        self.num_steps = num_diffusion_steps

        # Condition encoder over the raw input features
        self.condition_encoder = DiffActConditionEncoder(input_dim, hidden_dim, num_encoder_layers)

        # Initial (non-diffusion) prediction head
        self.initial_head = nn.Conv1d(hidden_dim, num_classes, 1)

        # Diffusion timestep embedding
        self.time_emb = SinusoidalTimeEmbedding(hidden_dim)

        # Denoising network: consumes noisy class probabilities
        # concatenated with the condition features
        self.denoise_input = nn.Conv1d(num_classes + hidden_dim, hidden_dim, 1)
        self.denoise_blocks = nn.ModuleList()
        for i in range(num_denoise_layers):
            dilation = 2 ** (i % 5)
            self.denoise_blocks.append(DiffActBlock(hidden_dim, dilation, hidden_dim))
        self.denoise_output = nn.Conv1d(hidden_dim, num_classes, 1)

        self._setup_noise_schedule()

    def _setup_noise_schedule(self):
        """Cosine noise schedule (Nichol & Dhariwal, 2021)."""
        steps = self.num_steps
        s = 0.008
        t = torch.linspace(0, steps, steps + 1)
        alphas_cumprod = torch.cos(((t / steps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        betas = torch.clamp(betas, 0.0001, 0.999)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod))

    def _add_noise(self, x_start, t, noise=None):
        """Sample q(x_t | x_0) = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps."""
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
        return sqrt_alpha * x_start + sqrt_one_minus * noise
|
|
    def _denoise_step(self, x_noisy, cond_features, time_emb):
        """Single pass of the denoising network; predicts clean logits (x_0)."""
        x = torch.cat([x_noisy, cond_features], dim=1)
        x = self.denoise_input(x)
        for block in self.denoise_blocks:
            x = block(x, time_emb)
        return self.denoise_output(x)

    def forward(self, x):
        """
        Training: returns [initial_pred, denoised_pred]
        Inference: returns [initial_pred, iteratively_denoised_pred]
        """
        x_in = x.permute(0, 2, 1)
        B, _, T = x_in.shape

        cond = self.condition_encoder(x_in)
        initial_logits = self.initial_head(cond).permute(0, 2, 1)

        if self.training:
            # Noise the detached initial probabilities at a random timestep
            # and predict the clean sequence in a single denoising pass
            x_start = F.softmax(initial_logits, dim=-1).permute(0, 2, 1)
            t = torch.randint(0, self.num_steps, (B,), device=x.device)
            noise = torch.randn_like(x_start)
            x_noisy = self._add_noise(x_start.detach(), t, noise)
            time_emb = self.time_emb(t)
            denoised = self._denoise_step(x_noisy, cond, time_emb)
            return [initial_logits, denoised.permute(0, 2, 1)]
        else:
            # DDPM-style reverse process starting from pure noise. The
            # network is trained to predict x_0, so its output is converted
            # to an epsilon estimate before the posterior-mean update.
            x_t = torch.randn(B, self.num_classes, T, device=x.device)
            for step in reversed(range(self.num_steps)):
                t = torch.full((B,), step, device=x.device, dtype=torch.long)
                time_emb = self.time_emb(t)
                x0_pred = self._denoise_step(x_t, cond, time_emb)
                pred_noise = (
                    x_t - self.sqrt_alphas_cumprod[step] * x0_pred
                ) / self.sqrt_one_minus_alphas_cumprod[step]
                beta = self.betas[step]
                x_t = (1 / torch.sqrt(1 - beta)) * (
                    x_t - beta / self.sqrt_one_minus_alphas_cumprod[step] * pred_noise
                )
                if step > 0:
                    # Damped noise injection for the stochastic reverse step
                    x_t = x_t + torch.sqrt(beta) * torch.randn_like(x_t) * 0.5
            return [initial_logits, x_t.permute(0, 2, 1)]
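

# Mode sketch (hypothetical sizes): training returns one-step denoised
# logits alongside the initial prediction; eval runs the full reverse chain.
def _demo_diffact_modes():
    model = DiffAct(input_dim=30, num_classes=10)
    x = torch.randn(2, 100, 30)
    model.train()
    initial, denoised = model(x)       # both (2, 100, 10)
    model.eval()
    with torch.no_grad():
        initial, refined = model(x)    # refined via 10 reverse steps
    assert initial.shape == refined.shape == (2, 100, 10)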
|
|
|
|
# ---------------------------------------------------------------------------
# 5. UnderPressure (Mourot et al., SCA/CGF 2022)
# ---------------------------------------------------------------------------
|
|
class UnderPressureContact(nn.Module):
    """UnderPressure model adapted for hand contact detection (Exp3).

    Architecture: Conv feature extractor -> BiGRU -> contact prediction head
    Input: (B, T, C)
    Output: (B, T, 2) [right_contact, left_contact]
    """

    def __init__(self, input_dim, hidden_dim=64, num_gru_layers=2):
        super().__init__()
        # Convolutional feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        # Bidirectional GRU over the extracted features
        self.gru = nn.GRU(
            hidden_dim, hidden_dim, num_layers=num_gru_layers,
            batch_first=True, bidirectional=True,
            dropout=0.2 if num_gru_layers > 1 else 0,
        )
        # Per-frame contact classification head
        self.contact_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for the conv stack, then back for the GRU
        feat = self.feature_extractor(x.permute(0, 2, 1))
        feat = feat.permute(0, 2, 1)
        gru_out, _ = self.gru(feat)
        return self.contact_head(gru_out)
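

# Shape sketch (hypothetical sizes): per-frame logits for the right and
# left hand contact channels.
def _demo_underpressure_contact():
    model = UnderPressureContact(input_dim=30)
    logits = model(torch.randn(4, 200, 30))
    assert logits.shape == (4, 200, 2)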
|
|
|
|
class UnderPressureRegressor(nn.Module):
    """UnderPressure model adapted for MoCap -> Pressure regression (Exp4a).

    Architecture: Conv feature extractor -> BiGRU -> pressure regression head
    Input: (B, T, input_dim)
    Output: (B, T, output_dim)
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128, num_gru_layers=2):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        self.gru = nn.GRU(
            hidden_dim, hidden_dim, num_layers=num_gru_layers,
            batch_first=True, bidirectional=True,
            dropout=0.2 if num_gru_layers > 1 else 0,
        )
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        feat = self.feature_extractor(x.permute(0, 2, 1))
        feat = feat.permute(0, 2, 1)
        gru_out, _ = self.gru(feat)
        return self.regression_head(gru_out)
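

# Regression sketch, e.g. mapping 72 MoCap channels to 16 pressure
# channels; both dimensions are illustrative, not benchmark settings.
def _demo_underpressure_regressor():
    model = UnderPressureRegressor(input_dim=72, output_dim=16)
    pressure = model(torch.randn(4, 200, 72))
    assert pressure.shape == (4, 200, 16)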
|
|
|
|
# ---------------------------------------------------------------------------
# 6. emg2pose (Meta, NeurIPS 2024 D&B)
# ---------------------------------------------------------------------------
|
|
class EMG2PoseEncoder(nn.Module):
    """CNN + Transformer encoder from emg2pose."""

    def __init__(self, input_dim, hidden_dim=128, num_transformer_layers=4, nhead=4):
        super().__init__()
        # Multi-scale convolutional stem; branch widths sum to hidden_dim
        self.conv_small = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim // 2, 3, padding=1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
        )
        self.conv_medium = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim // 4, 7, padding=3),
            nn.BatchNorm1d(hidden_dim // 4),
            nn.ReLU(),
        )
        self.conv_large = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim // 4, 15, padding=7),
            nn.BatchNorm1d(hidden_dim // 4),
            nn.ReLU(),
        )
        # Fuse the concatenated branches
        self.proj = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        # Temporal self-attention over the fused features
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers)

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for the conv branches
        x_t = x.permute(0, 2, 1)
        f_small = self.conv_small(x_t)
        f_medium = self.conv_medium(x_t)
        f_large = self.conv_large(x_t)
        feat = torch.cat([f_small, f_medium, f_large], dim=1)
        feat = self.proj(feat).permute(0, 2, 1)
        return self.transformer(feat)
|
|
|
|
class EMG2Pose(nn.Module):
    """emg2pose model for EMG -> Hand Pose regression (Exp4b).

    Predicts per-frame hand joint positions from EMG signals.
    Uses velocity-based integration (vemg2pose variant):
    predict velocity -> integrate to get positions.

    Input: (B, T, input_dim) [EMG channels]
    Output: (B, T, output_dim) [hand joint positions]
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128,
                 num_transformer_layers=4, use_velocity=True):
        super().__init__()
        self.use_velocity = use_velocity
        self.encoder = EMG2PoseEncoder(input_dim, hidden_dim, num_transformer_layers)

        if use_velocity:
            # Predict per-frame velocities, integrated in forward()
            self.velocity_head = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim // 2, output_dim),
            )
            # Learned initial pose for the integration
            self.initial_pos = nn.Parameter(torch.zeros(1, 1, output_dim))
        else:
            # Direct per-frame position regression
            self.position_head = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim // 2, output_dim),
            )

    def forward(self, x):
        features = self.encoder(x)

        if self.use_velocity:
            velocity = self.velocity_head(features)
            # Cumulative sum integrates velocities into positions,
            # offset by the learned initial pose
            positions = torch.cumsum(velocity, dim=1) + self.initial_pos
            return positions
        else:
            return self.position_head(features)
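

# Variant sketch (hypothetical sizes), e.g. 16 EMG channels to 20 joint
# angles; the velocity variant integrates to positions via cumsum, the
# direct variant regresses positions per frame.
def _demo_emg2pose():
    x = torch.randn(2, 500, 16)
    for use_velocity in (True, False):
        model = EMG2Pose(input_dim=16, output_dim=20, use_velocity=use_velocity)
        pose = model(x)
        assert pose.shape == (2, 500, 20)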
|
|