|
|
"""
|
|
|
BaramNuri (바람누리) - Lightweight Driver Behavior Detection Model
|
|
|
|
|
|
A hybrid architecture combining:
|
|
|
- Video Swin Transformer (Stage 1-3) for spatial features
|
|
|
- Selective State Space Model (SSM) for temporal modeling
|
|
|
|
|
|
Trained via Knowledge Distillation from Video Swin-T teacher.
|
|
|
|
|
|
Author: C-Team
|
|
|
License: Apache-2.0
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from torchvision.models.video import swin3d_t, Swin3D_T_Weights
|
|
|
from typing import Dict, Tuple
|
|
|
|
|
|
|
|
|
class SelectiveSSM(nn.Module):
    """Selective State Space Model (Mamba-style) temporal mixing layer.

    The SSM parameters B, C and delta are generated from the input at every
    time step, so the recurrence can selectively retain important
    information and quickly forget the rest.

    NOTE: this is a simplified selective scan — delta is a single scalar per
    time step (shared across all inner channels) and the recurrence is
    evaluated with an explicit Python loop over the sequence length.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dropout: float = 0.1):
        """
        Args:
            d_model: Input/output feature dimension.
            d_state: SSM hidden-state size per channel.
            d_conv: Kernel size of the depthwise causal convolution.
            expand: Inner-width expansion factor (d_inner = d_model * expand).
            dropout: Dropout applied after the output projection.
        """
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_model * expand

        # Joint projection producing both the SSM branch and the gate branch.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise causal convolution over time; the extra right padding is
        # trimmed back to the original length in forward().
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # Emits the concatenation [B_t | C_t | delta_t] for every time step.
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # A is stored via its log so that -exp(A_log) is always negative
        # (a stable, decaying transition).
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the selective SSM with a pre-norm residual connection.

        Args:
            x: [B, T, D] input sequence.
        Returns:
            [B, T, D] output sequence (same shape as the input).
        """
        skip = x
        x = self.layer_norm(x)

        n_batch, seq_len, _ = x.shape

        # Split the joint projection into the SSM branch and the gate branch.
        branch, gate = self.in_proj(x).chunk(2, dim=-1)

        # Causal depthwise convolution; drop the look-ahead padding tail.
        branch = branch.transpose(1, 2)
        branch = self.conv1d(branch)[:, :, :seq_len]
        branch = branch.transpose(1, 2)

        branch = F.silu(branch)

        # Input-dependent SSM parameters for every time step.
        params = self.x_proj(branch)
        b_seq = params[:, :, :self.d_state]
        c_seq = params[:, :, self.d_state:self.d_state * 2]
        dt = F.softplus(params[:, :, -1:])  # positive step sizes

        # Continuous-time decay rates, kept strictly negative.
        A = -torch.exp(self.A_log)

        # Discretized transition exp(delta * A) per step, each factor in (0, 1).
        decay = torch.exp(dt * A.view(1, 1, -1))

        # Sequential scan with hidden state of shape [B, d_inner, d_state].
        state = torch.zeros(n_batch, self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
        ys = []

        for step in range(seq_len):
            u = branch[:, step, :]
            b_t = b_seq[:, step, :]
            c_t = c_seq[:, step, :]
            a_t = decay[:, step, :]

            # h <- A_bar * h + B_t * u   (selective recurrence)
            state = state * a_t.unsqueeze(1) + b_t.unsqueeze(1) * u.unsqueeze(-1)

            # y = C_t . h + D * u        (readout plus skip term)
            ys.append((c_t.unsqueeze(1) * state).sum(dim=-1) + self.D * u)

        y = torch.stack(ys, dim=1)

        # SiLU gating by the second branch of in_proj.
        y = y * F.silu(gate)

        # Project back to model width, regularize, and add the residual.
        y = self.dropout(self.out_proj(y))

        return y + skip
|
|
|
|
|
|
|
|
|
class TemporalSSMBlock(nn.Module):
    """Stack of SelectiveSSM layers with temporal average pooling.

    Consumes a [B, T, C] token sequence and produces one [B, C] clip-level
    vector by mean-pooling over time after the SSM stack.
    """

    def __init__(self, d_model: int, d_state: int = 16, n_layers: int = 2, dropout: float = 0.1):
        """
        Args:
            d_model: Feature dimension of the sequence.
            d_state: SSM state size forwarded to each SelectiveSSM layer.
            n_layers: Number of stacked SelectiveSSM layers.
            dropout: Dropout forwarded to each SelectiveSSM layer.
        """
        super().__init__()

        layers = [
            SelectiveSSM(d_model, d_state=d_state, dropout=dropout)
            for _ in range(n_layers)
        ]
        self.ssm_layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the SSM stack and pool over time.

        Args:
            x: [B, T, D] sequence.
        Returns:
            [B, D] clip-level representation.
        """
        for layer in self.ssm_layers:
            x = layer(x)

        # Temporal average pooling -> one vector per clip.
        return x.mean(dim=1)
|
|
|
|
|
|
|
|
|
class BaramNuri(nn.Module):
    """
    BaramNuri (바람누리) - Lightweight Driver Behavior Detection Model

    Architecture:
    1. Video Swin-T patch embedding + stages 1-3 (spatial features, 384 dim)
    2. Selective SSM block (temporal modeling)
    3. Classification head

    Parameters: 14.20M (49% reduction from teacher's 27.86M)
    Performance: 96.17% accuracy, 0.9504 Macro F1
    """

    # Index-aligned class labels (Korean / English). These are runtime values
    # returned by predict(); the Korean strings are intentionally not translated.
    CLASS_NAMES = ["정상", "졸음운전", "물건찾기", "휴대폰 사용", "운전자 폭행"]
    CLASS_NAMES_EN = ["normal", "drowsy_driving", "searching_object", "phone_usage", "driver_assault"]

    def __init__(
        self,
        num_classes: int = 5,
        pretrained: bool = True,
        d_state: int = 16,
        ssm_layers: int = 2,
        dropout: float = 0.2,
    ):
        """
        Args:
            num_classes: Number of output classes.
            pretrained: Load Kinetics-400 weights into the Swin backbone.
            d_state: SSM state dimension.
            ssm_layers: Number of stacked SelectiveSSM layers.
            dropout: Dropout for the SSM layers and the classification head.
        """
        super().__init__()

        self.num_classes = num_classes

        if pretrained:
            print("Loading Swin backbone (Kinetics-400 pretrained)...")
            full_swin = swin3d_t(weights=Swin3D_T_Weights.KINETICS400_V1)
        else:
            full_swin = swin3d_t(weights=None)

        # 3D patch embedding taken from the Swin-T backbone.
        self.patch_embed = full_swin.patch_embed

        # features[0..4] = stage 1, patch-merge, stage 2, patch-merge, stage 3.
        # Stage 4 (768-dim) is intentionally dropped to cut parameters; stage 3
        # outputs 384-dim features.
        self.features = nn.Sequential(*[full_swin.features[i] for i in range(5)])

        # Channel width after Swin-T stage 3.
        self.feature_dim = 384

        # NOTE(review): not used by extract_features/forward; parameter-free,
        # so it does not affect state_dict. Kept so existing references to
        # model.avgpool keep working.
        self.avgpool = nn.AdaptiveAvgPool3d(output_size=1)

        # Temporal modeling over the per-frame 384-dim tokens.
        self.temporal_ssm = TemporalSSMBlock(
            d_model=self.feature_dim,
            d_state=d_state,
            n_layers=ssm_layers,
            dropout=dropout,
        )

        # Classification head: norm -> dropout -> linear.
        self.head = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(p=dropout),
            nn.Linear(self.feature_dim, num_classes),
        )

        self._init_head()

        # Release the unused teacher stages (stage 4, final norm, classifier).
        del full_swin

    def _init_head(self):
        """Initialize head linear layers with truncated-normal weights and zero bias."""
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract clip-level features (also used for knowledge distillation).

        Args:
            x: [B, C, T, H, W] video tensor
        Returns:
            features: [B, feature_dim]
        """
        # [B, C, T, H, W] -> [B, T', H', W', C'] patch tokens (channels-last).
        x = self.patch_embed(x)

        # Swin stages 1-3 -> [B, T', H', W', 384].
        x = self.features(x)

        # Spatial average pooling over H', W' -> [B, T', 384].
        # (The previous unused `B, T, H, W, C = x.shape` unpacking was removed.)
        x = x.mean(dim=[2, 3])

        # Temporal SSM + mean pooling -> [B, 384].
        x = self.temporal_ssm(x)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: [B, C, T, H, W] video tensor
        Returns:
            logits: [B, num_classes]
        """
        features = self.extract_features(x)
        logits = self.head(features)
        return logits

    def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return both logits and features (for knowledge distillation).

        Args:
            x: [B, C, T, H, W] video tensor
        Returns:
            (logits [B, num_classes], features [B, feature_dim])
        """
        features = self.extract_features(x)
        logits = self.head(features)
        return logits, features

    def predict(self, x: torch.Tensor, return_english: bool = False) -> Dict:
        """
        Inference prediction for a single clip.

        Args:
            x: [1, C, T, H, W] single video
            return_english: Return English class names
        Returns:
            dict with class, confidence, class_name, all_probs
        """
        # Fix: remember and restore the training flag so that calling
        # predict() does not permanently switch a training model to eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                probs = F.softmax(logits, dim=-1)[0]
                class_idx = probs.argmax().item()
        finally:
            if was_training:
                self.train()

        class_names = self.CLASS_NAMES_EN if return_english else self.CLASS_NAMES

        return {
            "class": class_idx,
            "confidence": probs[class_idx].item(),
            "class_name": class_names[class_idx],
            "all_probs": {
                name: probs[i].item()
                for i, name in enumerate(class_names)
            },
        }

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, device: str = 'cpu'):
        """
        Load pretrained model from checkpoint.

        Args:
            checkpoint_path: Path to .pth file
            device: 'cpu' or 'cuda'
        Returns:
            Loaded model in eval mode
        """
        # Fix: build with pretrained=False — the Kinetics-400 backbone weights
        # would be completely overwritten by load_state_dict() below, so
        # downloading them here is wasted time and bandwidth.
        model = cls(num_classes=5, pretrained=False)

        # SECURITY NOTE(review): torch.load unpickles arbitrary Python
        # objects — only load checkpoints from trusted sources (consider
        # weights_only=True on torch >= 2.0 for tensor-only checkpoints).
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Accept both {'model_state_dict': ...} training checkpoints and raw state dicts.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model = model.to(device)
        model.eval()

        return model
|
|
|
|
|
|
|
|
|
def count_parameters(model: nn.Module) -> int:
    """Return the total number of parameters in *model* (trainable or not)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model, run a batched forward pass, then exercise
    # predict() in both label languages.
    banner = "=" * 60
    print(banner)
    print("BaramNuri Model Test")
    print(banner)

    model = BaramNuri(num_classes=5, pretrained=True)

    n_params = count_parameters(model)
    print(f"\nTotal parameters: {n_params:,} ({n_params/1e6:.2f}M)")

    # Batch of 2 clips, 30 RGB frames at 224x224.
    batch = torch.randn(2, 3, 30, 224, 224)
    print(f"\nInput shape: {batch.shape}")

    model.eval()
    with torch.no_grad():
        logits = model(batch)
    print(f"Output shape: {logits.shape}")

    # Single-clip inference via the predict() helper.
    clip = torch.randn(1, 3, 30, 224, 224)
    result_ko = model.predict(clip)
    print(f"\nPrediction (Korean): {result_ko['class_name']} ({result_ko['confidence']:.2%})")

    result_en = model.predict(clip, return_english=True)
    print(f"Prediction (English): {result_en['class_name']} ({result_en['confidence']:.2%})")

    print("\nModel test passed!")
|
|
|
|