| """ |
| Session Tracker Module |
| |
| Manages and tracks all network sessions across the virtual ISP stack: |
| - Unified session management across all modules |
| - Session lifecycle tracking |
| - Performance metrics and analytics |
| - Session correlation and debugging |
| """ |
|
|
| import time |
| import threading |
| import uuid |
| from typing import Dict, List, Optional, Set, Any, Tuple |
| from dataclasses import dataclass, field |
| from enum import Enum |
| import json |
|
|
| from .dhcp_server import DHCPLease |
| from .nat_engine import NATSession |
| from .tcp_engine import TCPConnection |
| from .socket_translator import SocketConnection |
|
|
|
|
class SessionType(Enum):
    """Category of a tracked session.

    Each member's value is identical to its name; that string value is what
    the tracker uses as the key in its per-type statistics.
    """

    DHCP_LEASE = "DHCP_LEASE"                # presumably mirrors .dhcp_server.DHCPLease
    NAT_SESSION = "NAT_SESSION"              # presumably mirrors .nat_engine.NATSession
    TCP_CONNECTION = "TCP_CONNECTION"        # presumably mirrors .tcp_engine.TCPConnection
    SOCKET_CONNECTION = "SOCKET_CONNECTION"  # presumably mirrors .socket_translator.SocketConnection
    BRIDGE_CLIENT = "BRIDGE_CLIENT"
|
|
|
|
class SessionState(Enum):
    """Lifecycle state of a tracked session.

    Values equal member names.  Within this module, new sessions start as
    INITIALIZING and ``close_session`` moves them to CLOSED; the remaining
    states are only ever assigned by callers.
    """

    INITIALIZING = "INITIALIZING"
    ACTIVE = "ACTIVE"
    IDLE = "IDLE"
    CLOSING = "CLOSING"
    CLOSED = "CLOSED"
    ERROR = "ERROR"
|
|
|
|
@dataclass
class SessionMetrics:
    """Traffic counters and a bounded RTT history for one session.

    Counters are accumulated through the ``update_*`` helpers (or by direct
    attribute arithmetic); only the 100 most recent RTT samples are retained.
    """

    bytes_in: int = 0
    bytes_out: int = 0
    packets_in: int = 0
    packets_out: int = 0
    errors: int = 0
    retransmits: int = 0
    # Recent RTT measurements, oldest first, capped at 100 entries.
    rtt_samples: List[float] = field(default_factory=list)

    @property
    def total_bytes(self) -> int:
        """Inbound plus outbound bytes."""
        inbound, outbound = self.bytes_in, self.bytes_out
        return inbound + outbound

    @property
    def total_packets(self) -> int:
        """Inbound plus outbound packets."""
        inbound, outbound = self.packets_in, self.packets_out
        return inbound + outbound

    @property
    def average_rtt(self) -> float:
        """Mean of the retained RTT samples; 0.0 when there are none."""
        samples = self.rtt_samples
        if not samples:
            return 0.0
        return sum(samples) / len(samples)

    def update_bytes(self, bytes_in: int = 0, bytes_out: int = 0):
        """Add the given deltas to the byte counters."""
        self.bytes_in, self.bytes_out = (
            self.bytes_in + bytes_in,
            self.bytes_out + bytes_out,
        )

    def update_packets(self, packets_in: int = 0, packets_out: int = 0):
        """Add the given deltas to the packet counters."""
        self.packets_in, self.packets_out = (
            self.packets_in + packets_in,
            self.packets_out + packets_out,
        )

    def add_rtt_sample(self, rtt: float):
        """Record one RTT measurement, dropping the oldest beyond 100."""
        self.rtt_samples.append(rtt)
        overflow = len(self.rtt_samples) - 100
        if overflow > 0:
            # Rebind to a fresh list holding only the newest 100 samples.
            self.rtt_samples = self.rtt_samples[overflow:]

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot including derived totals."""
        return dict(
            bytes_in=self.bytes_in,
            bytes_out=self.bytes_out,
            packets_in=self.packets_in,
            packets_out=self.packets_out,
            total_bytes=self.total_bytes,
            total_packets=self.total_packets,
            errors=self.errors,
            retransmits=self.retransmits,
            average_rtt=self.average_rtt,
            rtt_samples_count=len(self.rtt_samples),
        )
|
|
|
|
@dataclass
class UnifiedSession:
    """Module-agnostic record of one tracked session.

    Timestamps are ``time.time()`` seconds.  Leaving ``created_time`` or
    ``last_activity`` at 0 (the default) stamps it with the construction
    time in ``__post_init__``; an empty ``session_id`` is replaced by a
    random UUID4 string.
    """
    session_id: str
    session_type: SessionType
    state: SessionState
    # Defaulted so callers may omit them; the 0 sentinel is resolved to "now"
    # in __post_init__.  Without defaults the sentinel could never be used
    # and constructing a session without explicit timestamps raised TypeError.
    created_time: float = 0.0
    last_activity: float = 0.0

    # Optional endpoint addressing.
    virtual_ip: Optional[str] = None
    virtual_port: Optional[int] = None
    real_ip: Optional[str] = None
    real_port: Optional[int] = None
    protocol: Optional[str] = None

    # Correlation links, stored as session ids.
    related_sessions: Set[str] = field(default_factory=set)
    parent_session: Optional[str] = None
    child_sessions: Set[str] = field(default_factory=set)

    metrics: SessionMetrics = field(default_factory=SessionMetrics)

    # Free-form per-session annotations (e.g. 'close_reason').
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Fill in a generated id and "now" timestamps where left empty/zero."""
        if not self.session_id:
            self.session_id = str(uuid.uuid4())
        # Single clock read so auto-filled timestamps start out equal.
        now = time.time()
        if self.created_time == 0:
            self.created_time = now
        if self.last_activity == 0:
            self.last_activity = now

    def update_activity(self):
        """Set ``last_activity`` to the current time."""
        self.last_activity = time.time()

    def add_related_session(self, session_id: str):
        """Record *session_id* as related to this session (one direction only)."""
        self.related_sessions.add(session_id)

    def add_child_session(self, session_id: str):
        """Record *session_id* as a child of this session."""
        self.child_sessions.add(session_id)

    def set_parent_session(self, session_id: str):
        """Set (overwrite) this session's parent id."""
        self.parent_session = session_id

    @property
    def duration(self) -> float:
        """Seconds elapsed since ``created_time``."""
        return time.time() - self.created_time

    @property
    def idle_time(self) -> float:
        """Seconds elapsed since ``last_activity``."""
        return time.time() - self.last_activity

    def to_dict(self) -> Dict:
        """Return a JSON-friendly snapshot of the session.

        Enum fields become their string values and id sets become sorted
        lists, so repeated exports are stable.  ``metrics`` is expanded via
        ``SessionMetrics.to_dict``.  ``metadata`` is shallow-copied, so a
        caller cannot mutate the live session through the snapshot (None
        becomes {}).  ``duration`` and ``idle_time`` are evaluated at call time.
        """
        return {
            'session_id': self.session_id,
            'session_type': self.session_type.value,
            'state': self.state.value,
            'created_time': self.created_time,
            'last_activity': self.last_activity,
            'duration': self.duration,
            'idle_time': self.idle_time,
            'virtual_ip': self.virtual_ip,
            'virtual_port': self.virtual_port,
            'real_ip': self.real_ip,
            'real_port': self.real_port,
            'protocol': self.protocol,
            'related_sessions': sorted(self.related_sessions),
            'parent_session': self.parent_session,
            'child_sessions': sorted(self.child_sessions),
            'metrics': self.metrics.to_dict(),
            'metadata': dict(self.metadata or {}),
        }
|
|
|
|
class SessionTracker:
    """Unified, lock-protected registry of sessions from every stack module.

    Sessions are stored in ``self.sessions`` keyed by ``session_id``.  They are
    also filed in ``self.session_index`` under the ``(type_value, 'all')``,
    ``('virtual_ip', ip)``, ``('real_ip', ip)`` and ``('protocol', proto)``
    buckets, which :meth:`find_sessions` uses to narrow its scan.

    Every public method acquires ``self.lock``.  The private helpers
    (``_..._locked``, index helpers, ``_cleanup_expired_sessions``) expect
    the caller to already hold it.

    Config keys read (default in parentheses):
        max_sessions (10000): capacity limit enforced by ``create_session``.
        session_timeout (3600): seconds of inactivity after which a session
            that is not CLOSED is expired by the cleanup sweep.
        cleanup_interval (300): seconds between background sweeps; also how
            long a CLOSED session is kept after its last activity.
        metrics_retention (86400): stored on the instance; not read by any
            method of this class.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create a tracker; call :meth:`start` to begin background cleanup.

        Args:
            config: Settings mapping (see the class docstring).  ``None`` is
                treated as an empty mapping, i.e. all defaults.
        """
        if config is None:
            config = {}
        self.config = config
        self.sessions: Dict[str, UnifiedSession] = {}
        self.session_index: Dict[Tuple[str, str], Set[str]] = {}
        # session_id -> index keys it was filed under.  Removal touches
        # exactly those buckets, even if the session object was mutated since.
        self._index_keys: Dict[str, List[Tuple[str, str]]] = {}
        # Re-entrant, so code already holding the lock (including external
        # ``with tracker.lock:`` blocks) can still call public methods.
        self.lock = threading.RLock()

        self.max_sessions = config.get('max_sessions', 10000)
        self.session_timeout = config.get('session_timeout', 3600)
        self.cleanup_interval = config.get('cleanup_interval', 300)
        self.metrics_retention = config.get('metrics_retention', 86400)

        self.stats = self._fresh_stats()

        self.running = False
        self.cleanup_thread = None
        # Set by stop() to wake the cleanup thread out of its wait at once.
        self._stop_event = threading.Event()

    @staticmethod
    def _fresh_stats() -> Dict[str, Any]:
        """Return a stats dict with every counter at zero.

        ``total_sessions``, ``expired_sessions``, ``cleanup_runs`` and
        ``correlations_created`` are cumulative.  ``active_sessions``,
        ``session_types`` and ``session_states`` describe the sessions
        currently held.
        """
        return {
            'total_sessions': 0,
            'active_sessions': 0,
            'expired_sessions': 0,
            'session_types': {t.value: 0 for t in SessionType},
            'session_states': {s.value: 0 for s in SessionState},
            'cleanup_runs': 0,
            'correlations_created': 0,
        }

    def _generate_session_key(self, session_type: SessionType, **kwargs) -> str:
        """Build a readable natural key for a session of *session_type*.

        The key is derived from type-specific kwargs (MAC address, connection
        4-tuple, connection/client id, ...).  Not called by any method of
        this class.
        """
        if session_type == SessionType.DHCP_LEASE:
            return f"dhcp_{kwargs.get('mac_address', 'unknown')}"
        if session_type == SessionType.NAT_SESSION:
            return (f"nat_{kwargs.get('virtual_ip', '')}_{kwargs.get('virtual_port', 0)}"
                    f"_{kwargs.get('protocol', '')}")
        if session_type == SessionType.TCP_CONNECTION:
            return (f"tcp_{kwargs.get('local_ip', '')}_{kwargs.get('local_port', 0)}"
                    f"_{kwargs.get('remote_ip', '')}_{kwargs.get('remote_port', 0)}")
        if session_type == SessionType.SOCKET_CONNECTION:
            return f"socket_{kwargs.get('connection_id', 'unknown')}"
        if session_type == SessionType.BRIDGE_CLIENT:
            return f"bridge_{kwargs.get('client_id', 'unknown')}"
        return f"unknown_{time.time()}"

    @staticmethod
    def _index_keys_for(session: UnifiedSession) -> List[Tuple[str, str]]:
        """Return the index buckets *session* belongs in, from its current fields."""
        keys = [(session.session_type.value, 'all')]
        # Only truthy values are indexed; find_sessions relies on this.
        if session.virtual_ip:
            keys.append(('virtual_ip', session.virtual_ip))
        if session.real_ip:
            keys.append(('real_ip', session.real_ip))
        if session.protocol:
            keys.append(('protocol', session.protocol))
        return keys

    def _add_to_index(self, session: UnifiedSession):
        """File *session* in the buckets matching its current fields.

        Any previous filing of the same id is dropped first, so this also
        serves as "re-index".  Caller must hold ``self.lock``.
        """
        if session.session_id in self._index_keys:
            self._remove_from_index(session)
        keys = self._index_keys_for(session)
        for key in keys:
            self.session_index.setdefault(key, set()).add(session.session_id)
        self._index_keys[session.session_id] = keys

    def _remove_from_index(self, session: UnifiedSession):
        """Drop *session* from its index buckets and prune empty buckets.

        Caller must hold ``self.lock``.
        """
        keys = self._index_keys.pop(session.session_id, None)
        if keys is None:
            # Not filed via _add_to_index; fall back to scanning every bucket.
            keys = list(self.session_index)
        for key in keys:
            bucket = self.session_index.get(key)
            if bucket is None:
                continue
            bucket.discard(session.session_id)
            if not bucket:
                # Keep the index from accumulating empty sets for every
                # IP/protocol ever seen.
                del self.session_index[key]

    def _shift_count(self, bucket: str, old: str, new: str):
        """Move one unit of the per-type/per-state gauge *bucket* from *old* to *new*."""
        counts = self.stats[bucket]
        counts[old] -= 1
        counts[new] += 1

    def create_session(self, session_type: SessionType, **kwargs) -> Optional[str]:
        """Register a new session in the INITIALIZING state.

        Args:
            session_type: A SessionType member or its string value.
            **kwargs: Optional ``session_id`` (a UUID4 string is generated
                when missing or empty), ``virtual_ip``, ``virtual_port``,
                ``real_ip``, ``real_port``, ``protocol`` and ``metadata``
                (copied; ``None`` becomes ``{}``).  Other keys are ignored.

        Returns:
            The session id.  Returns ``None`` if the tracker is at
            ``max_sessions`` even after expiring stale sessions.  Reusing an
            existing ``session_id`` replaces that session: its index entries
            and counters are retired first, and the replacement does not
            count against capacity.

        Raises:
            ValueError: *session_type* is not a valid SessionType value.
        """
        session_type = SessionType(session_type)
        with self.lock:
            session_id = kwargs.get('session_id') or str(uuid.uuid4())
            if session_id in self.sessions:
                self._remove_session_locked(session_id)
            elif len(self.sessions) >= self.max_sessions:
                # Try to make room before refusing.
                self._cleanup_expired_sessions()
                if len(self.sessions) >= self.max_sessions:
                    return None

            now = time.time()
            metadata = kwargs.get('metadata')
            session = UnifiedSession(
                session_id=session_id,
                session_type=session_type,
                state=SessionState.INITIALIZING,
                # Passed explicitly so construction does not depend on
                # UnifiedSession providing defaults for the timestamps.
                created_time=now,
                last_activity=now,
                virtual_ip=kwargs.get('virtual_ip'),
                virtual_port=kwargs.get('virtual_port'),
                real_ip=kwargs.get('real_ip'),
                real_port=kwargs.get('real_port'),
                protocol=kwargs.get('protocol'),
                # Copy: the caller's dict must not alias the session's
                # (close_session writes 'close_reason' into it).
                metadata=dict(metadata) if metadata else {},
            )

            self.sessions[session_id] = session
            self._add_to_index(session)

            self.stats['total_sessions'] += 1
            self.stats['active_sessions'] = len(self.sessions)
            self.stats['session_types'][session_type.value] += 1
            self.stats['session_states'][SessionState.INITIALIZING.value] += 1

            return session_id

    def update_session(self, session_id: str, **kwargs) -> bool:
        """Assign fields on a session and refresh its ``last_activity``.

        Only UnifiedSession dataclass fields are assigned.  ``session_id``
        (the dict key) is immutable, and unknown keys and non-field
        attributes (properties, methods) are ignored.  ``state`` and
        ``session_type`` may be given as enum members or their string values.
        The search index and per-state/per-type counters are kept in sync.

        Returns:
            False if no such session exists, else True.

        Raises:
            ValueError: an invalid ``state`` / ``session_type`` value.  This
                is raised before anything is modified.
        """
        from dataclasses import fields as dataclass_fields

        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return False

            updatable = {f.name for f in dataclass_fields(session)} - {'session_id'}
            changes = {name: value for name, value in kwargs.items() if name in updatable}
            # Normalize enum fields up front so a bad value fails before mutation.
            if 'state' in changes:
                changes['state'] = SessionState(changes['state'])
            if 'session_type' in changes:
                changes['session_type'] = SessionType(changes['session_type'])

            old_state = session.state
            old_type = session.session_type
            for name, value in changes.items():
                setattr(session, name, value)
            session.update_activity()

            if session.state != old_state:
                self._shift_count('session_states', old_state.value, session.state.value)
            if session.session_type != old_type:
                self._shift_count('session_types', old_type.value, session.session_type.value)

            # Indexed fields (type, IPs, protocol) may have changed.
            self._add_to_index(session)
            return True

    def close_session(self, session_id: str, reason: str = "") -> bool:
        """Mark a session CLOSED; it is removed later by the cleanup sweep.

        Args:
            session_id: Session to close.
            reason: If non-empty, stored as ``metadata['close_reason']``.

        Returns:
            False if no such session exists, else True.
        """
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return False

            old_state = session.state
            session.state = SessionState.CLOSED
            session.update_activity()

            if reason:
                session.metadata['close_reason'] = reason

            self._shift_count('session_states', old_state.value, SessionState.CLOSED.value)
            return True

    def remove_session(self, session_id: str) -> bool:
        """Delete a session and its index entries; returns False if absent."""
        with self.lock:
            return self._remove_session_locked(session_id)

    def _remove_session_locked(self, session_id: str) -> bool:
        """Body of :meth:`remove_session`; caller must hold ``self.lock``."""
        session = self.sessions.pop(session_id, None)
        if session is None:
            return False

        self._remove_from_index(session)

        self.stats['active_sessions'] = len(self.sessions)
        self.stats['session_types'][session.session_type.value] -= 1
        self.stats['session_states'][session.state.value] -= 1
        return True

    def get_session(self, session_id: str) -> Optional[UnifiedSession]:
        """Return the live session object for *session_id*, or None."""
        with self.lock:
            return self.sessions.get(session_id)

    def find_sessions(self, **criteria) -> List[UnifiedSession]:
        """Return sessions whose attributes equal every given criterion.

        ``session_type`` and ``state`` may be enum members or their string
        values.  A criterion naming an attribute the session lacks never
        matches.  With no criteria, all sessions are returned.
        """
        with self.lock:
            return self._find_sessions_locked(criteria)

    def _find_sessions_locked(self, criteria: Dict[str, Any]) -> List[UnifiedSession]:
        """Body of :meth:`find_sessions`; caller must hold ``self.lock``."""
        criteria = dict(criteria)
        # Normalize enum criteria so the equality filter below compares
        # enum to enum (a raw string would never equal the stored member).
        try:
            if 'session_type' in criteria:
                criteria['session_type'] = SessionType(criteria['session_type'])
            if 'state' in criteria:
                criteria['state'] = SessionState(criteria['state'])
        except ValueError:
            return []  # an unknown type/state value cannot match anything

        # Narrow the scan with the index when possible.  Only truthy values
        # are ever indexed, so a falsy value (e.g. virtual_ip=None) must fall
        # through to a full scan rather than an always-empty bucket.
        candidate_ids = None
        if 'session_type' in criteria:
            candidate_ids = self.session_index.get((criteria['session_type'].value, 'all'), set())
        else:
            for attr in ('virtual_ip', 'real_ip', 'protocol'):
                value = criteria.get(attr)
                if value:
                    candidate_ids = self.session_index.get((attr, value), set())
                    break
        if candidate_ids is None:
            candidate_ids = self.sessions.keys()

        matches = []
        for candidate_id in candidate_ids:
            session = self.sessions.get(candidate_id)
            if session is None:
                continue
            if all(hasattr(session, key) and getattr(session, key) == value
                   for key, value in criteria.items()):
                matches.append(session)
        return matches

    def correlate_sessions(self, session_id1: str, session_id2: str, relationship: str = 'related') -> bool:
        """Link two sessions.

        Args:
            session_id1: First session (the parent for ``'parent_child'``).
            session_id2: Second session (the child for ``'parent_child'``).
            relationship: ``'parent_child'`` for a directed link; any other
                value records a symmetric "related" link.

        Returns:
            False if either session does not exist, else True.
        """
        with self.lock:
            session1 = self.sessions.get(session_id1)
            session2 = self.sessions.get(session_id2)

            if session1 is None or session2 is None:
                return False

            if relationship == 'parent_child':
                # A session has a single parent: detach it from any previous
                # parent so that parent's child set stays consistent.
                previous_parent = session2.parent_session
                if previous_parent and previous_parent != session_id1:
                    old_parent = self.sessions.get(previous_parent)
                    if old_parent is not None:
                        old_parent.child_sessions.discard(session_id2)
                session1.add_child_session(session_id2)
                session2.set_parent_session(session_id1)
            else:
                session1.add_related_session(session_id2)
                session2.add_related_session(session_id1)

            self.stats['correlations_created'] += 1
            return True

    def update_metrics(self, session_id: str, **metrics) -> bool:
        """Accumulate traffic metrics into a session and refresh its activity.

        Recognized keys are ``bytes_in``, ``bytes_out``, ``packets_in``,
        ``packets_out``, ``errors`` and ``retransmits`` (all added to the
        counters), plus ``rtt`` (appended as one sample).  Other keys are
        ignored.

        Returns:
            False if no such session exists, else True.
        """
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return False

            session.update_activity()
            counters = session.metrics

            if 'bytes_in' in metrics or 'bytes_out' in metrics:
                counters.update_bytes(metrics.get('bytes_in', 0), metrics.get('bytes_out', 0))

            if 'packets_in' in metrics or 'packets_out' in metrics:
                counters.update_packets(metrics.get('packets_in', 0), metrics.get('packets_out', 0))

            if 'rtt' in metrics:
                counters.add_rtt_sample(metrics['rtt'])

            if 'errors' in metrics:
                counters.errors += metrics['errors']

            if 'retransmits' in metrics:
                counters.retransmits += metrics['retransmits']

            return True

    def _cleanup_expired_sessions(self):
        """Remove stale sessions; caller must hold ``self.lock``.

        A CLOSED session is removed once idle longer than
        ``cleanup_interval``.  Any other session is removed once idle longer
        than ``session_timeout``.  Each removal increments
        ``stats['expired_sessions']``.
        """
        now = time.time()
        expired = []
        for session_id, session in self.sessions.items():
            limit = (self.cleanup_interval if session.state == SessionState.CLOSED
                     else self.session_timeout)
            if now - session.last_activity > limit:
                expired.append(session_id)

        for session_id in expired:
            # Lock-free helper: calling remove_session here would re-acquire
            # the lock this method's caller already holds.
            if self._remove_session_locked(session_id):
                self.stats['expired_sessions'] += 1

    def _cleanup_loop(self):
        """Background thread body: sweep, then wait ``cleanup_interval`` until stopped.

        After an error the message is printed and the loop waits 60 seconds
        before retrying.  Waits end early when :meth:`stop` is called.
        """
        while self.running:
            try:
                with self.lock:
                    self._cleanup_expired_sessions()
                    self.stats['cleanup_runs'] += 1
                delay = self.cleanup_interval
            except Exception as e:  # thread boundary: report and keep running
                print(f"Session tracker cleanup error: {e}")
                delay = 60

            if self._stop_event.wait(delay):
                break

    def get_sessions(self, limit: int = 100, offset: int = 0, **filters) -> List[Dict]:
        """Return one page of session snapshots, most recently active first.

        Args:
            limit: Maximum number of sessions returned.
            offset: Number of sessions to skip after sorting.
            **filters: Optional criteria, interpreted as in :meth:`find_sessions`.
        """
        with self.lock:
            if filters:
                sessions = self._find_sessions_locked(filters)
            else:
                sessions = list(self.sessions.values())

            sessions.sort(key=lambda s: s.last_activity, reverse=True)
            return [session.to_dict() for session in sessions[offset:offset + limit]]

    def get_session_summary(self) -> Dict:
        """Count current sessions by type, state, protocol and recency of activity.

        The ``active_sessions_by_age`` buckets are based on ``last_activity``:
        within the last hour, within the last day (but not the last hour), or
        older.
        """
        with self.lock:
            summary = {
                'total_sessions': len(self.sessions),
                'by_type': {},
                'by_state': {},
                'by_protocol': {},
                'active_sessions_by_age': {
                    'last_hour': 0,
                    'last_day': 0,
                    'older': 0
                }
            }
            by_type = summary['by_type']
            by_state = summary['by_state']
            by_protocol = summary['by_protocol']
            by_age = summary['active_sessions_by_age']

            now = time.time()
            hour_ago = now - 3600
            day_ago = now - 86400

            for session in self.sessions.values():
                type_name = session.session_type.value
                by_type[type_name] = by_type.get(type_name, 0) + 1

                state_name = session.state.value
                by_state[state_name] = by_state.get(state_name, 0) + 1

                if session.protocol:
                    by_protocol[session.protocol] = by_protocol.get(session.protocol, 0) + 1

                if session.last_activity > hour_ago:
                    by_age['last_hour'] += 1
                elif session.last_activity > day_ago:
                    by_age['last_day'] += 1
                else:
                    by_age['older'] += 1

            return summary

    def get_stats(self) -> Dict:
        """Return a copy of the tracker statistics.

        The nested ``session_types`` / ``session_states`` dicts are copied
        too, so callers cannot alter (or observe later changes to) the live
        counters.
        """
        with self.lock:
            stats = dict(self.stats)
            stats['session_types'] = dict(self.stats['session_types'])
            stats['session_states'] = dict(self.stats['session_states'])
            stats['active_sessions'] = len(self.sessions)
        return stats

    def reset_stats(self):
        """Zero the cumulative counters and rebuild gauges from current sessions.

        ``total_sessions`` and ``active_sessions`` restart at the current
        session count.  Everything happens under the lock so concurrent
        updates cannot interleave with the rebuild.
        """
        with self.lock:
            stats = self._fresh_stats()
            stats['total_sessions'] = len(self.sessions)
            stats['active_sessions'] = len(self.sessions)
            for session in self.sessions.values():
                stats['session_types'][session.session_type.value] += 1
                stats['session_states'][session.state.value] += 1
            self.stats = stats

    def export_sessions(self, format: str = 'json') -> str:
        """Serialize all sessions.

        Args:
            format: Output format; only ``'json'`` is supported.  Values that
                JSON cannot encode are converted with ``str``.

        Raises:
            ValueError: unsupported *format*.
        """
        if format != 'json':
            raise ValueError(f"Unsupported export format: {format}")
        with self.lock:
            sessions_data = [session.to_dict() for session in self.sessions.values()]
        return json.dumps(sessions_data, indent=2, default=str)

    def start(self):
        """Start the background cleanup thread (a no-op if it is already running)."""
        if self.cleanup_thread is not None and self.cleanup_thread.is_alive():
            return
        self._stop_event.clear()
        self.running = True
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()
        print("Session tracker started")

    def stop(self):
        """Signal the cleanup thread to exit and wait for it to finish.

        Because the thread waits on an Event, it wakes immediately instead
        of finishing a full ``cleanup_interval`` sleep first.
        """
        self.running = False
        self._stop_event.set()
        thread = self.cleanup_thread
        # Joining the current thread would raise RuntimeError.
        if thread is not None and thread is not threading.current_thread():
            thread.join()
        print("Session tracker stopped")
|
|
|
|